diff --git a/HISTORY.rst b/HISTORY.rst
index 21e6d4bb..764b602d 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -7,13 +7,44 @@ This project adheres to `Semantic Versioning `_.
vNext
===========================
-* `#660`_: Userspace connection migration. This should be an invisible change, but represents a significant refactor of how connections function.
-Instead of a weird blend of DatabaseConnection models and underlying Django models (which were the original Explorer connections),
-this migrates all connections to DatabaseConnection models and implements proper foreign keys to them on the Query and QueryLog models.
-A data migration creates new DatabaseConnection models based on the configured settings.EXPLORER_CONNECTIONS.
-Going forward, admins can create new Django-backed DatabaseConnection models by registering the connection in EXPLORER_CONNECTIONS, and then creating a
-DatabaseConnection model using the Django admin or the user-facing /connections/new/ form, and entering the Django DB alias and setting the connection type to "Django Connection"
-
+* `#664`_: Improvements to the AI SQL Assistant:
+
+ - Table Annotations: Write persistent table annotations with descriptive information that will get injected into the
+ prompt for the assistant. For example, if a table is commonly joined to another table through a non-obvious foreign
+ key, you can tell the assistant about it in plain English, as an annotation to that table. Every time that table is
+ deemed 'relevant' to an assistant request, that annotation will be included alongside the schema and sample data.
+ - Few-Shot Examples: Using the small checkbox on the bottom-right of any saved query, you can designate certain
+ queries as 'few-shot examples'. When making an assistant request, any designated few-shot examples that reference
+ the same tables as your assistant request will get included as 'reference sql' in the prompt for the LLM.
+ - Autocomplete / multiselect when selecting table info to send to the SQL Assistant. Much easier and more keyboard
+ focused.
+ - Relevant tables are added client-side visually, in real time, based on what's in the SQL editor and/or any tables
+ mentioned in the assistant request. The dependency on sql_metadata is therefore removed, as server-side SQL parsing
+ is no longer necessary.
+ - Ability to view Assistant request/response history.
+ - Improved system prompt that emphasizes the particular SQL dialect being used.
+ - Addresses issue #657.
+
+* `#660`_: Userspace connection migration.
+
+ - This should be an invisible change, but represents a significant refactor of how connections function. Instead of a
+ weird blend of DatabaseConnection models and underlying Django models (which were the original Explorer
+ connections), this migrates all connections to DatabaseConnection models and implements proper foreign keys to them
+ on the Query and QueryLog models. A data migration creates new DatabaseConnection models based on the configured
+ settings.EXPLORER_CONNECTIONS. Going forward, admins can create new Django-backed DatabaseConnection models by
+ registering the connection in EXPLORER_CONNECTIONS, and then creating a DatabaseConnection model using the Django
+ admin or the user-facing /connections/new/ form, and entering the Django DB alias and setting the connection type
+ to "Django Connection".
+ - The Query.connection and QueryLog.connection fields are deprecated and will be removed in a future release. They
+ are kept around in this release in case there is an unforeseen issue with the migration. Preserving the fields for
+ now ensures there is no data loss in the event that a rollback to an earlier version is required.
+
+* Fixed a bug when validating connections to uploaded files. Also added basic locking when downloading files from S3.
+
+* Keyboard shortcut for formatting the SQL in the editor.
+
+ - Cmd+Shift+F (Windows: Ctrl+Shift+F)
+ - The format button has been moved to be a small icon towards the bottom-right of the SQL editor.
`5.2.0`_ (2024-08-19)
===========================
@@ -643,6 +674,8 @@ Initial Release
.. _#651: https://github.com/explorerhq/sql-explorer/pull/651
.. _#659: https://github.com/explorerhq/sql-explorer/pull/659
.. _#662: https://github.com/explorerhq/sql-explorer/pull/662
+.. _#660: https://github.com/explorerhq/sql-explorer/pull/660
+.. _#664: https://github.com/explorerhq/sql-explorer/pull/664
.. _#269: https://github.com/explorerhq/sql-explorer/issues/269
.. _#288: https://github.com/explorerhq/sql-explorer/issues/288
diff --git a/docs/features.rst b/docs/features.rst
index a93e9b39..c9dd8f42 100644
--- a/docs/features.rst
+++ b/docs/features.rst
@@ -5,7 +5,19 @@ SQL Assistant
-------------
- Built in integration with OpenAI (or the LLM of your choosing)
to quickly get help with your query, with relevant schema
- automatically injected into the prompt. Simple, effective.
+ automatically injected into the prompt.
+- The assistant tries hard to get relevant context into the prompt to the LLM, alongside your explicit request. You
+ can choose tables to include explicitly (and any tables you reference in your SQL will be included as
+ well). When a table is "included", the prompt will include the schema of the table, 3 sample rows, any Table
+ Annotations you have added, and any designated "few shot examples". More on each of those below.
+- Table Annotations: Write persistent table annotations with descriptive information that will get injected into the
+ prompt for the assistant. For example, if a table is commonly joined to another table through a non-obvious foreign
+ key, you can tell the assistant about it in plain English, as an annotation to that table. Every time that table is
+ deemed 'relevant' to an assistant request, that annotation will be included alongside the schema and sample data.
+- Few-shot examples: Using the small checkbox on the bottom-right of any saved query, you can designate queries as
+ "Assistant Examples". When making an assistant request, the 'included tables' are intersected with tables referenced
+ by designated Example queries, and those queries are injected into the prompt, and the LLM is told that these
+ are good reference queries.
Database Support
----------------
@@ -222,8 +234,7 @@ Power tips
view.
- Command+Enter and Ctrl+Enter will execute a query when typing in
the SQL editor area.
-- Hit the "Format" button to format and clean up your SQL (this is
- non-validating -- just formatting).
+- Cmd+Shift+F (Windows: Ctrl+Shift+F) to format the SQL in the editor.
- Use the Query Logs feature to share one-time queries that aren't
worth creating a persistent query for. Just run your SQL in the
playground, then navigate to ``/logs`` and share the link
diff --git a/explorer/admin.py b/explorer/admin.py
index 0cad25ac..219c87ce 100644
--- a/explorer/admin.py
+++ b/explorer/admin.py
@@ -7,7 +7,7 @@
@admin.register(Query)
class QueryAdmin(admin.ModelAdmin):
- list_display = ("title", "description", "created_by_user",)
+ list_display = ("title", "description", "created_by_user", "few_shot")
list_filter = ("title",)
raw_id_fields = ("created_by_user",)
actions = [generate_report_action()]
diff --git a/explorer/app_settings.py b/explorer/app_settings.py
index a93257b3..8bf69b63 100644
--- a/explorer/app_settings.py
+++ b/explorer/app_settings.py
@@ -152,11 +152,17 @@
EXPLORER_AI_API_KEY = getattr(settings, "EXPLORER_AI_API_KEY", None)
EXPLORER_ASSISTANT_BASE_URL = getattr(settings, "EXPLORER_ASSISTANT_BASE_URL", "https://api.openai.com/v1")
+
+# Deprecated. Will be removed in a future release. Please use EXPLORER_ASSISTANT_MODEL_NAME instead
EXPLORER_ASSISTANT_MODEL = getattr(settings, "EXPLORER_ASSISTANT_MODEL",
# Return the model name and max_tokens it supports
{"name": "gpt-4o",
"max_tokens": 128000})
+EXPLORER_ASSISTANT_MODEL_NAME = getattr(settings, "EXPLORER_ASSISTANT_MODEL_NAME",
+ EXPLORER_ASSISTANT_MODEL["name"])
+
+
EXPLORER_DB_CONNECTIONS_ENABLED = getattr(settings, "EXPLORER_DB_CONNECTIONS_ENABLED", False)
EXPLORER_USER_UPLOADS_ENABLED = getattr(settings, "EXPLORER_USER_UPLOADS_ENABLED", False)
EXPLORER_PRUNE_LOCAL_UPLOAD_COPY_DAYS_INACTIVITY = getattr(settings,
diff --git a/explorer/assistant/forms.py b/explorer/assistant/forms.py
new file mode 100644
index 00000000..ddcaffea
--- /dev/null
+++ b/explorer/assistant/forms.py
@@ -0,0 +1,42 @@
+from django import forms
+from explorer.assistant.models import TableDescription
+from explorer.ee.db_connections.utils import default_db_connection
+
+
+class TableDescriptionForm(forms.ModelForm):
+ class Meta:
+ model = TableDescription
+ fields = "__all__"
+ widgets = {
+ "database_connection": forms.Select(attrs={"class": "form-select"}),
+ "description": forms.Textarea(attrs={"class": "form-control", "rows": 3}),
+ }
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if not self.instance.pk: # Check if this is a new instance
+ # Set the default value for database_connection
+ self.fields["database_connection"].initial = default_db_connection()
+
+ if self.instance and self.instance.table_name:
+ choices = [(self.instance.table_name, self.instance.table_name)]
+ else:
+ choices = []
+
+ f = forms.ChoiceField(
+ choices=choices,
+ widget=forms.Select(attrs={"class": "form-select", "data-placeholder": "Select table"})
+ )
+
+ # We don't actually care about validating the 'choices' that the ChoiceField does by default.
+ # Really we are just using that field type in order to get a valid pre-populated Select widget on the client
+ # But also it can't be blank!
+ def valid_value_new(v):
+ return bool(v)
+
+ f.valid_value = valid_value_new
+
+ self.fields["table_name"] = f
+
+ if self.instance and self.instance.table_name:
+ self.fields["table_name"].initial = self.instance.table_name
diff --git a/explorer/assistant/models.py b/explorer/assistant/models.py
index 21c6db6d..556148c1 100644
--- a/explorer/assistant/models.py
+++ b/explorer/assistant/models.py
@@ -1,5 +1,6 @@
from django.db import models
from django.conf import settings
+from explorer.ee.db_connections.models import DatabaseConnection
class PromptLog(models.Model):
@@ -8,6 +9,7 @@ class Meta:
app_label = "explorer"
prompt = models.TextField(blank=True)
+ user_request = models.TextField(blank=True)
response = models.TextField(blank=True)
run_by_user = models.ForeignKey(
settings.AUTH_USER_MODEL,
@@ -19,3 +21,18 @@ class Meta:
duration = models.FloatField(blank=True, null=True) # seconds
model = models.CharField(blank=True, max_length=128, default="")
error = models.TextField(blank=True, null=True)
+ database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, blank=True, null=True)
+
+
+class TableDescription(models.Model):
+
+ class Meta:
+ app_label = "explorer"
+ unique_together = ("database_connection", "table_name")
+
+ database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.CASCADE)
+ table_name = models.CharField(max_length=512)
+ description = models.TextField()
+
+ def __str__(self):
+ return f"{self.database_connection.alias} - {self.table_name}"
diff --git a/explorer/assistant/urls.py b/explorer/assistant/urls.py
new file mode 100644
index 00000000..bf6e479b
--- /dev/null
+++ b/explorer/assistant/urls.py
@@ -0,0 +1,16 @@
+from django.urls import path
+from explorer.assistant.views import (TableDescriptionListView,
+ TableDescriptionCreateView,
+ TableDescriptionUpdateView,
+ TableDescriptionDeleteView,
+ AssistantHelpView,
+ AssistantHistoryApiView)
+
+assistant_urls = [
+ path("assistant/", AssistantHelpView.as_view(), name="assistant"),
+ path("assistant/history/", AssistantHistoryApiView.as_view(), name="assistant_history"),
+ path("table-descriptions/", TableDescriptionListView.as_view(), name="table_description_list"),
+ path("table-descriptions/new/", TableDescriptionCreateView.as_view(), name="table_description_create"),
+ path("table-descriptions//update/", TableDescriptionUpdateView.as_view(), name="table_description_update"),
+ path("table-descriptions//delete/", TableDescriptionDeleteView.as_view(), name="table_description_delete"),
+]
diff --git a/explorer/assistant/utils.py b/explorer/assistant/utils.py
index 3dd3cdfb..0a962406 100644
--- a/explorer/assistant/utils.py
+++ b/explorer/assistant/utils.py
@@ -1,12 +1,16 @@
+from dataclasses import dataclass
from explorer import app_settings
from explorer.schema import schema_info
-from explorer.models import ExplorerValue
+from explorer.models import ExplorerValue, Query
from django.db.utils import OperationalError
+from django.db.models.functions import Lower
+from django.db.models import Q
+from explorer.assistant.models import TableDescription
OPENAI_MODEL = app_settings.EXPLORER_ASSISTANT_MODEL["name"]
-ROW_SAMPLE_SIZE = 2
-MAX_FIELD_SAMPLE_SIZE = 500 # characters
+ROW_SAMPLE_SIZE = 3
+MAX_FIELD_SAMPLE_SIZE = 200 # characters
def openai_client():
@@ -34,25 +38,17 @@ def extract_response(r):
return r[-1].content
-def tables_from_schema_info(db_connection, table_names):
+def table_schema(db_connection, table_name):
schema = schema_info(db_connection)
- return [table for table in schema if table[0] in set(table_names)]
-
-
-def sample_rows_from_tables(connection, table_names):
- ret = ""
- for table_name in table_names:
- ret += f"SAMPLE FROM TABLE {table_name}:\n"
- ret += format_rows_from_table(
- sample_rows_from_table(connection, table_name)
- ) + "\n\n"
- return ret
+ s = [table for table in schema if table[0] == table_name]
+ if len(s):
+ return s[0][1]
def sample_rows_from_table(connection, table_name):
"""
Fetches a sample of rows from the specified table and ensures that any field values
- exceeding 500 characters (or bytes) are truncated. This is useful for handling fields
+ exceeding 200 characters (or bytes) are truncated. This is useful for handling fields
like "description" that might contain very long strings of text or binary data.
Truncating these fields prevents issues with displaying or processing overly large values.
An ellipsis ("...") is appended to indicate that the data has been truncated.
@@ -76,8 +72,8 @@ def sample_rows_from_table(connection, table_name):
new_val = field
if isinstance(field, str) and len(field) > MAX_FIELD_SAMPLE_SIZE:
new_val = field[:MAX_FIELD_SAMPLE_SIZE] + "..." # Truncate and add ellipsis
- elif isinstance(field, (bytes, bytearray)) and len(field) > MAX_FIELD_SAMPLE_SIZE:
- new_val = field[:MAX_FIELD_SAMPLE_SIZE] + b"..." # Truncate binary data
+ elif isinstance(field, (bytes, bytearray)):
+ new_val = ""
processed_row.append(new_val)
ret.append(processed_row)
@@ -87,66 +83,97 @@ def sample_rows_from_table(connection, table_name):
def format_rows_from_table(rows):
- column_headers = list(rows[0])
- ret = " | ".join(column_headers) + "\n" + "-" * 50 + "\n"
- for row in rows[1:]:
- row_str = " | ".join(str(item) for item in row)
- ret += row_str + "\n"
- return ret
-
-
-def get_table_names_from_query(sql):
- from sql_metadata import Parser
- if sql:
- try:
- parsed = Parser(sql)
- return parsed.tables
- except ValueError:
- return []
- return []
-
-
-def num_tokens_from_string(string: str) -> int:
- """Returns the number of tokens in a text string."""
- import tiktoken
- try:
- encoding = tiktoken.encoding_for_model(OPENAI_MODEL)
- except KeyError:
- encoding = tiktoken.get_encoding("cl100k_base")
- num_tokens = len(encoding.encode(string))
- return num_tokens
+ """ Given an array of rows (a list of lists), returns e.g.
+AlbumId | Title | ArtistId
+1 | For Those About To Rock We Salute You | 1
+2 | Let It Rip | 2
+3 | Restless and Wild | 2
-def fits_in_window(string: str) -> bool:
- # Ratchet down by 5% to account for other boilerplate and system prompt
- # TODO make this better by actually looking at the token count of the system prompt
- return num_tokens_from_string(string) < (app_settings.EXPLORER_ASSISTANT_MODEL["max_tokens"] * 0.95)
+ """
+ return "\n".join([" | ".join([str(item) for item in row]) for row in rows])
-def build_prompt(db_connection, assistant_request, included_tables, query_error=None, sql=None):
- user_prompt = ""
- djc = db_connection.as_django_connection()
- user_prompt += f"## Database Vendor / SQL Flavor is {djc.vendor}\n\n"
+def build_system_prompt(flavor):
+ bsp = ExplorerValue.objects.get_item(ExplorerValue.ASSISTANT_SYSTEM_PROMPT).value
+ bsp += f"\nYou are an expert at writing SQL, specifically for {flavor}, and account for the nuances of this dialect of SQL. You always respond with valid {flavor} SQL." # noqa
+ return bsp
+
+
+def get_relevant_annotation(db_connection, t):
+ return TableDescription.objects.annotate(
+ table_name_lower=Lower("table_name")
+ ).filter(
+ database_connection=db_connection,
+ table_name_lower=t.lower()
+ ).first()
+
+
+def get_relevant_few_shots(db_connection, included_tables):
+ included_tables_lower = [t.lower() for t in included_tables]
+
+ query_conditions = Q()
+ for table in included_tables_lower:
+ query_conditions |= Q(sql__icontains=table)
- if query_error:
- user_prompt += f"## Query Error ##\n\n{query_error}\n\n"
+ return Query.objects.annotate(
+ sql_lower=Lower("sql")
+ ).filter(
+ database_connection=db_connection,
+ few_shot=True
+ ).filter(query_conditions)
- if sql:
- user_prompt += f"## Existing SQL ##\n\n{sql}\n\n"
- results_sample = sample_rows_from_tables(djc,
- included_tables)
- if fits_in_window(user_prompt + results_sample):
- user_prompt += f"## Table Structure with Sampled Data ##\n\n{results_sample}\n\n"
- else: # If it's too large with sampling, then provide *just* the structure
- table_struct = tables_from_schema_info(db_connection,
- included_tables)
- user_prompt += f"## Table Structure ##\n\n{table_struct}\n\n"
+def get_few_shot_chunk(db_connection, included_tables):
+ included_tables = [t.lower() for t in included_tables]
+ few_shot_examples = get_relevant_few_shots(db_connection, included_tables)
+ if few_shot_examples:
+ return "## Relevant example queries, written by expert SQL analysts ##\n" + "\n\n".join(
+ [f"Description: {fs.title} - {fs.description}\nSQL:\n{fs.sql}"
+ for fs in few_shot_examples.all()]
+ )
+
+
+@dataclass
+class TablePromptData:
+ name: str
+ schema: list
+ sample: list
+ annotation: TableDescription
+
+ def render(self):
+ fmt_schema = "\n".join([str(field) for field in self.schema])
+ ret = f"""## Information for Table '{self.name}' ##
+
+Schema:\n{fmt_schema}
+
+Sample rows:\n{format_rows_from_table(self.sample)}"""
+ if self.annotation:
+ ret += f"\nUsage Notes:\n{self.annotation.description}"
+ return ret
+
+
+def build_prompt(db_connection, assistant_request, included_tables, query_error=None, sql=None):
+ included_tables = [t.lower() for t in included_tables]
+
+ error_chunk = f"## Query Error ##\n{query_error}" if query_error else None
+ sql_chunk = f"## Existing User-Written SQL ##\n{sql}" if sql else None
+ request_chunk = f"## User's Request to Assistant ##\n{assistant_request}"
+ table_chunks = [
+ TablePromptData(
+ name=t,
+ schema=table_schema(db_connection, t),
+ sample=sample_rows_from_table(db_connection.as_django_connection(), t),
+ annotation=get_relevant_annotation(db_connection, t)
+ ).render()
+ for t in included_tables
+ ]
+ few_shot_chunk = get_few_shot_chunk(db_connection, included_tables)
- user_prompt += f"## User's Request to Assistant ##\n\n{assistant_request}\n\n"
+ chunks = [error_chunk, sql_chunk, *table_chunks, few_shot_chunk, request_chunk]
prompt = {
- "system": ExplorerValue.objects.get_item(ExplorerValue.ASSISTANT_SYSTEM_PROMPT).value,
- "user": user_prompt
+ "system": build_system_prompt(db_connection.as_django_connection().vendor),
+ "user": "\n\n".join([c for c in chunks if c]),
}
return prompt
diff --git a/explorer/assistant/views.py b/explorer/assistant/views.py
index f5d6f2aa..2571e216 100644
--- a/explorer/assistant/views.py
+++ b/explorer/assistant/views.py
@@ -1,14 +1,21 @@
from django.http import JsonResponse
from django.views import View
from django.utils import timezone
+from django.views.generic import ListView, CreateView, UpdateView, DeleteView
+from django.urls import reverse_lazy
+
+from .forms import TableDescriptionForm
+from .models import TableDescription
+
import json
+from explorer.views.auth import PermissionRequiredMixin
+from explorer.views.mixins import ExplorerContextMixin
from explorer.telemetry import Stat, StatNames
from explorer.ee.db_connections.models import DatabaseConnection
from explorer.assistant.models import PromptLog
from explorer.assistant.utils import (
do_req, extract_response,
- get_table_names_from_query,
build_prompt
)
@@ -16,16 +23,15 @@
def run_assistant(request_data, user):
sql = request_data.get("sql")
- extra_tables = request_data.get("selected_tables", [])
- included_tables = get_table_names_from_query(sql) + extra_tables
+ included_tables = request_data.get("selected_tables", [])
connection_id = request_data.get("connection_id")
try:
conn = DatabaseConnection.objects.get(id=connection_id)
except DatabaseConnection.DoesNotExist:
return "Error: Connection not found"
-
- prompt = build_prompt(conn, request_data.get("assistant_request"),
+ assistant_request = request_data.get("assistant_request")
+ prompt = build_prompt(conn, assistant_request,
included_tables, request_data.get("db_error"), request_data.get("sql"))
start = timezone.now()
@@ -33,6 +39,8 @@ def run_assistant(request_data, user):
prompt=prompt,
run_by_user=user,
run_at=timezone.now(),
+ user_request=assistant_request,
+ database_connection=conn
)
response_text = None
try:
@@ -47,7 +55,6 @@ def run_assistant(request_data, user):
pl.save()
Stat(StatNames.ASSISTANT_RUN, {
"included_table_count": len(included_tables),
- "extra_table_count": len(extra_tables),
"has_sql": bool(sql),
"duration": pl.duration,
}).track()
@@ -67,3 +74,51 @@ def post(self, request, *args, **kwargs):
return JsonResponse(response_data)
except json.JSONDecodeError:
return JsonResponse({"status": "error", "message": "Invalid JSON"}, status=400)
+
+
+class TableDescriptionListView(PermissionRequiredMixin, ExplorerContextMixin, ListView):
+ model = TableDescription
+ permission_required = "view_permission"
+ template_name = "assistant/table_description_list.html"
+ context_object_name = "table_descriptions"
+
+
+class TableDescriptionCreateView(PermissionRequiredMixin, ExplorerContextMixin, CreateView):
+ model = TableDescription
+ permission_required = "change_permission"
+ template_name = "assistant/table_description_form.html"
+ success_url = reverse_lazy("table_description_list")
+ form_class = TableDescriptionForm
+
+
+class TableDescriptionUpdateView(PermissionRequiredMixin, ExplorerContextMixin, UpdateView):
+ model = TableDescription
+ permission_required = "change_permission"
+ template_name = "assistant/table_description_form.html"
+ success_url = reverse_lazy("table_description_list")
+ form_class = TableDescriptionForm
+
+
+class TableDescriptionDeleteView(PermissionRequiredMixin, ExplorerContextMixin, DeleteView):
+ model = TableDescription
+ permission_required = "change_permission"
+ template_name = "assistant/table_description_confirm_delete.html"
+ success_url = reverse_lazy("table_description_list")
+
+
+class AssistantHistoryApiView(View):
+
+ def post(self, request, *args, **kwargs):
+ try:
+ data = json.loads(request.body)
+ logs = PromptLog.objects.filter(
+ run_by_user=request.user,
+ database_connection_id=data["connection_id"]
+ ).order_by("-run_at")[:5]
+ ret = [{
+ "user_request": log.user_request,
+ "response": log.response
+ } for log in logs]
+ return JsonResponse({"logs": ret})
+ except (json.JSONDecodeError, KeyError):
+ return JsonResponse({"status": "error", "message": "Invalid JSON or missing connection_id"}, status=400)
diff --git a/explorer/ee/db_connections/models.py b/explorer/ee/db_connections/models.py
index ea81a8f8..140543ac 100644
--- a/explorer/ee/db_connections/models.py
+++ b/explorer/ee/db_connections/models.py
@@ -4,7 +4,7 @@
from django.db.utils import load_backend
from explorer.app_settings import EXPLORER_CONNECTIONS
from explorer.ee.db_connections.utils import quick_hash, uploaded_db_local_path
-
+from django.core.cache import cache
from django_cryptography.fields import encrypt
@@ -66,12 +66,27 @@ def _download_sqlite(self):
s3 = get_s3_bucket()
s3.download_file(self.host, self.local_name)
+ def _download_needed(self):
+ # If the file doesn't exist, obviously we need to download it
+ # If it does exist, then check if it's out of date. But only check if in fact the DatabaseConnection has been
+ # saved to the DB. For example, we might be validating an unsaved connection, in which case the fingerprint
+ # won't be set yet.
+ return (not os.path.exists(self.local_name) or
+ (self.id is not None and self.local_fingerprint() != self.upload_fingerprint))
+
def download_sqlite_if_needed(self):
- download = not os.path.exists(self.local_name) or self.local_fingerprint() != self.upload_fingerprint
- if download:
- self._download_sqlite()
- self.update_fingerprint()
+ if self._download_needed():
+ cache_key = f"download_lock_{self.local_name}"
+ lock_acquired = cache.add(cache_key, "locked", timeout=300) # Timeout after 5 minutes
+
+ if lock_acquired:
+ try:
+ if self._download_needed():
+ self._download_sqlite()
+ self.update_fingerprint()
+ finally:
+ cache.delete(cache_key)
@property
def is_upload(self):
diff --git a/explorer/ee/urls.py b/explorer/ee/urls.py
index 0359d628..e1bfdc92 100644
--- a/explorer/ee/urls.py
+++ b/explorer/ee/urls.py
@@ -20,12 +20,7 @@
path("connections/create_upload/", DatabaseConnectionUploadCreateView.as_view(), name="explorer_upload_create"),
path("connections//edit/", DatabaseConnectionUpdateView.as_view(), name="explorer_connection_update"),
path("connections//delete/", DatabaseConnectionDeleteView.as_view(), name="explorer_connection_delete"),
- # There are two URLs here because the form can call validate from /connections/new/ or from /connections//edit/
- # which have different relative paths. It's easier to just provide both of these URLs rather than deal with this
- # client-side.
path("connections/validate/", DatabaseConnectionValidateView.as_view(), name="explorer_connection_validate"),
- path("connections//validate/", DatabaseConnectionValidateView.as_view(),
- name="explorer_connection_validate_with_pk"),
path("connections//refresh/", DatabaseConnectionRefreshView.as_view(),
name="explorer_connection_refresh")
]
diff --git a/explorer/forms.py b/explorer/forms.py
index fb6511c2..e025b558 100644
--- a/explorer/forms.py
+++ b/explorer/forms.py
@@ -34,6 +34,7 @@ class QueryForm(ModelForm):
sql = SqlField()
snapshot = BooleanField(widget=CheckboxInput, required=False)
+ few_shot = BooleanField(widget=CheckboxInput, required=False)
database_connection = CharField(widget=Select, required=False)
def __init__(self, *args, **kwargs):
@@ -77,4 +78,4 @@ def connections(self):
class Meta:
model = Query
- fields = ["title", "sql", "description", "snapshot", "database_connection"]
+ fields = ["title", "sql", "description", "snapshot", "database_connection", "few_shot"]
diff --git a/explorer/migrations/0026_tabledescription.py b/explorer/migrations/0026_tabledescription.py
new file mode 100644
index 00000000..782ed1bc
--- /dev/null
+++ b/explorer/migrations/0026_tabledescription.py
@@ -0,0 +1,26 @@
+# Generated by Django 5.0.4 on 2024-08-22 01:24
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('explorer', '0025_alter_query_database_connection_alter_querylog_database_connection'),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name='TableDescription',
+ fields=[
+ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+ ('table_name', models.CharField(max_length=512)),
+ ('description', models.TextField()),
+ ('database_connection', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='explorer.databaseconnection')),
+ ],
+ options={
+ 'unique_together': {('database_connection', 'table_name')},
+ },
+ ),
+ ]
diff --git a/explorer/migrations/0027_query_few_shot.py b/explorer/migrations/0027_query_few_shot.py
new file mode 100644
index 00000000..18e901fb
--- /dev/null
+++ b/explorer/migrations/0027_query_few_shot.py
@@ -0,0 +1,18 @@
+# Generated by Django 5.0.4 on 2024-08-25 21:26
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('explorer', '0026_tabledescription'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='query',
+ name='few_shot',
+ field=models.BooleanField(default=False, help_text='Will be included as a good example of SQL in assistant queries that use relevant tables'),
+ ),
+ ]
diff --git a/explorer/migrations/0028_promptlog_database_connection_promptlog_user_request.py b/explorer/migrations/0028_promptlog_database_connection_promptlog_user_request.py
new file mode 100644
index 00000000..dfba7809
--- /dev/null
+++ b/explorer/migrations/0028_promptlog_database_connection_promptlog_user_request.py
@@ -0,0 +1,24 @@
+# Generated by Django 5.0.4 on 2024-08-27 18:59
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('explorer', '0027_query_few_shot'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='promptlog',
+ name='database_connection',
+ field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='explorer.databaseconnection'),
+ ),
+ migrations.AddField(
+ model_name='promptlog',
+ name='user_request',
+ field=models.TextField(blank=True),
+ ),
+ ]
diff --git a/explorer/models.py b/explorer/models.py
index 9fe8ef11..67939df8 100644
--- a/explorer/models.py
+++ b/explorer/models.py
@@ -9,8 +9,6 @@
from django.utils.translation import gettext_lazy as _
from explorer import app_settings
-# import the models so that the migration tooling knows the assistant models are part of the explorer app
-from explorer.assistant import models as assistant_models # noqa
from explorer.telemetry import Stat, StatNames
from explorer.utils import (
extract_params, get_params_for_url, get_s3_bucket, passes_blacklist, s3_url,
@@ -21,7 +19,7 @@
# Issue #618. All models must be imported so that Django understands how to manage migrations for the app
from explorer.ee.db_connections.models import DatabaseConnection # noqa
-from explorer.assistant.models import PromptLog # noqa
+from explorer.assistant.models import PromptLog, TableDescription # noqa
MSG_FAILED_BLACKLIST = "Query failed the SQL blacklist: %s"
@@ -58,6 +56,8 @@ class Query(models.Model):
)
)
database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, null=True)
+ few_shot = models.BooleanField(default=False, help_text=_(
+ "Will be included as a good example of SQL in assistant queries that use relevant tables"))
def __init__(self, *args, **kwargs):
self.params = kwargs.get("params")
diff --git a/explorer/src/js/assistant.js b/explorer/src/js/assistant.js
index 9e7c4690..8e7ce159 100644
--- a/explorer/src/js/assistant.js
+++ b/explorer/src/js/assistant.js
@@ -1,69 +1,99 @@
import {getCsrfToken} from "./csrf";
import { marked } from "marked";
import DOMPurify from "dompurify";
-import * as bootstrap from 'bootstrap';
-import List from "list.js";
+import * as bootstrap from "bootstrap";
import { SchemaSvc, getConnElement } from "./schemaService"
+import Choices from "choices.js"
function getErrorMessage() {
const errorElement = document.querySelector('.alert-danger.db-error');
return errorElement ? errorElement.textContent.trim() : null;
}
+function debounce(func, delay) {
+ let timeout;
+ return function(...args) {
+ clearTimeout(timeout);
+ timeout = setTimeout(() => func.apply(this, args), delay);
+ };
+}
+
function setupTableList() {
+
+ if(window.assistantChoices) {
+ window.assistantChoices.destroy();
+ }
+
SchemaSvc.get().then(schema => {
const keys = Object.keys(schema);
- const tableList = document.getElementById('table-list');
- tableList.innerHTML = '';
+ const selectElement = document.createElement('select');
+ selectElement.className = 'js-choice';
+ selectElement.toggleAttribute('multiple');
+ selectElement.toggleAttribute('data-trigger');
- keys.forEach((key, index) => {
- const div = document.createElement('div');
- div.className = 'form-check';
-
- const input = document.createElement('input');
- input.className = 'form-check-input table-checkbox';
- input.type = 'checkbox';
- input.value = key;
- input.id = 'flexCheckDefault' + index;
+ keys.forEach((key) => {
+ const option = document.createElement('option');
+ option.value = key;
+ option.textContent = key;
+ selectElement.appendChild(option);
+ });
- const label = document.createElement('label');
- label.className = 'form-check-label';
- label.setAttribute('for', input.id);
- label.textContent = key;
+ const tableList = document.getElementById('table-list');
+ tableList.innerHTML = '';
+ tableList.appendChild(selectElement);
- div.appendChild(input);
- div.appendChild(label);
- tableList.appendChild(div);
+ const choices = new Choices('.js-choice', {
+ removeItemButton: true,
+ searchEnabled: true,
+ shouldSort: false,
+ position: 'bottom'
});
- let options = {
- valueNames: ['form-check-label'],
- };
-
- new List('additional_table_container', options);
+ // TODO: refactor to avoid this window-level global. submitAssistantAsk reads it to get the selected tables.
+ window.assistantChoices = choices;
const selectAllButton = document.getElementById('select_all_button');
- const checkboxes = document.querySelectorAll('.table-checkbox');
-
- let selectState = 'all';
-
- selectAllButton.innerHTML = 'Select All';
-
selectAllButton.addEventListener('click', (e) => {
e.preventDefault();
- const isSelectingAll = selectState === 'all';
- checkboxes.forEach((checkbox) => {
- checkbox.checked = isSelectingAll;
+ choices.setChoiceByValue(keys);
+ });
+
+ const deselectAllButton = document.getElementById('deselect_all_button');
+ deselectAllButton.addEventListener('click', (e) => {
+ e.preventDefault();
+ keys.forEach(k => {
+ choices.removeActiveItemsByValue(k);
});
- selectState = isSelectingAll ? 'none' : 'all';
- selectAllButton.innerHTML = isSelectingAll ? 'Deselect All' : 'Select All';
});
+
+ selectRelevantTablesSql(choices, keys);
+
+ document.addEventListener('docChanged', debounce(
+ () => selectRelevantTablesSql(choices, keys), 500));
+
+ document.getElementById('id_assistant_input').addEventListener('input', debounce(
+ () => selectRelevantTablesRequest(choices, keys), 300));
+
})
.catch(error => {
console.error('Error retrieving JSON schema:', error);
});
}
+function selectRelevantTablesSql(choices, keys) {
+ const textContent = window.editor.state.doc.toString();
+ const textWords = new Set(textContent.split(/\s+/));
+ const hasKeys = keys.filter(key => textWords.has(key));
+ choices.setChoiceByValue(hasKeys);
+}
+
+function selectRelevantTablesRequest(choices, keys) {
+ const textContent = document.getElementById("id_assistant_input").value;
+ const textWords = new Set(textContent.split(/\s+/));
+ const hasKeys = keys.filter(key => textWords.has(key));
+ choices.setChoiceByValue(hasKeys);
+}
+
export function setUpAssistant(expand = false) {
getConnElement().addEventListener('change', setupTableList);
@@ -71,13 +101,13 @@ export function setUpAssistant(expand = false) {
const error = getErrorMessage();
- if(expand || error) {
+ if (expand || error) {
const myCollapseElement = document.getElementById('assistant_collapse');
const bsCollapse = new bootstrap.Collapse(myCollapseElement, {
- toggle: false
+ toggle: false
});
bsCollapse.show();
- if(error) {
+ if (error) {
document.getElementById('id_error_help_message').classList.remove('d-none');
}
}
@@ -85,7 +115,7 @@ export function setUpAssistant(expand = false) {
const tooltipTriggerList = document.querySelectorAll('[data-bs-toggle="tooltip"]');
[...tooltipTriggerList].map(tooltipTriggerEl => new bootstrap.Tooltip(tooltipTriggerEl));
- document.getElementById('id_assistant_input').addEventListener('keydown', function(event) {
+ document.getElementById('id_assistant_input').addEventListener('keydown', function (event) {
if ((event.ctrlKey || event.metaKey) && (event.key === 'Enter')) {
event.preventDefault();
submitAssistantAsk();
@@ -93,19 +123,105 @@ export function setUpAssistant(expand = false) {
});
document.getElementById('ask_assistant_btn').addEventListener('click', submitAssistantAsk);
+
+ document.getElementById('assistant_history').addEventListener('click', getAssistantHistory);
+
}
-function submitAssistantAsk() {
+function getAssistantHistory() {
+
+ const historyModalId = 'historyModal';
- const selectedTables = Array.from(
- document.querySelectorAll('.table-checkbox:checked')
- ).map(cb => cb.value);
+ // Remove any existing modal with the same ID
+ const existingModal = document.getElementById(historyModalId);
+ if (existingModal) {
+ existingModal.remove();
+ }
+
+ const data = {
+ connection_id: document.getElementById("id_database_connection")?.value ?? null
+ };
+
+ fetch(`${window.baseUrlPath}assistant/history/`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ 'X-CSRFToken': getCsrfToken()
+ },
+ body: JSON.stringify(data)
+ })
+ .then(response => {
+ if (!response.ok) {
+ throw new Error('Network response was not ok');
+ }
+ return response.json();
+ })
+ .then(data => {
+ // Create table rows from the fetched data
+ let tableRows = '';
+ data.logs.forEach(log => {
+ let md = DOMPurify.sanitize(marked.parse(log.response));
+ tableRows += `
+ <tr>
+ <td>${log.user_request}</td>
+ <td>${md}</td>
+ </tr>
+ `;
+ });
+
+ // Create the complete table HTML
+ const tableHtml = `
+ <table>
+ <thead>
+ <tr>
+ <th>User Request</th>
+ <th>Response</th>
+ </tr>
+ </thead>
+ <tbody>
+ ${tableRows}
+ </tbody>
+ </table>
+ `;
+
+ // Insert the table into a new Bootstrap modal
+ const modalHtml = `
+