From 02119fdcb75118799dc17e62d984ad203ebbfdc7 Mon Sep 17 00:00:00 2001 From: Chris Clark Date: Sat, 3 Aug 2024 13:38:09 -0400 Subject: [PATCH 1/4] assistant improvements --- HISTORY.rst | 41 +++- docs/features.rst | 17 +- explorer/admin.py | 2 +- explorer/assistant/forms.py | 42 ++++ explorer/assistant/models.py | 15 ++ explorer/assistant/urls.py | 14 ++ explorer/assistant/utils.py | 165 ++++++++------- explorer/assistant/views.py | 43 +++- explorer/ee/urls.py | 5 - explorer/forms.py | 3 +- explorer/migrations/0026_tabledescription.py | 26 +++ explorer/migrations/0027_query_few_shot.py | 18 ++ explorer/models.py | 6 +- explorer/src/js/assistant.js | 105 +++++----- explorer/src/js/codemirror-config.js | 50 ++++- explorer/src/js/explorer.js | 6 +- explorer/src/js/main.js | 5 +- explorer/src/js/query-list.js | 2 +- explorer/src/js/schemaService.js | 2 +- explorer/src/js/tableDescription.js | 41 ++++ explorer/src/js/uploads.js | 4 +- explorer/src/scss/assistant.scss | 6 - explorer/src/scss/styles.scss | 1 + .../table_description_confirm_delete.html | 14 ++ .../assistant/table_description_form.html | 39 ++++ .../assistant/table_description_list.html | 38 ++++ explorer/templates/explorer/assistant.html | 23 ++- explorer/templates/explorer/base.html | 7 + explorer/templates/explorer/play.html | 11 +- explorer/templates/explorer/query.html | 7 +- explorer/tests/test_assistant.py | 189 ++++++++++-------- explorer/tests/test_views.py | 19 +- explorer/urls.py | 4 +- package-lock.json | 74 ++++++- package.json | 1 + requirements/extra/assistant.txt | 2 - 36 files changed, 771 insertions(+), 276 deletions(-) create mode 100644 explorer/assistant/forms.py create mode 100644 explorer/assistant/urls.py create mode 100644 explorer/migrations/0026_tabledescription.py create mode 100644 explorer/migrations/0027_query_few_shot.py create mode 100644 explorer/src/js/tableDescription.js create mode 100644 explorer/templates/assistant/table_description_confirm_delete.html create mode 100644 explorer/templates/assistant/table_description_form.html create mode 100644 explorer/templates/assistant/table_description_list.html diff --git a/HISTORY.rst b/HISTORY.rst index 21e6d4bb..e76ea74d 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -7,13 +7,40 @@ This project adheres to `Semantic Versioning `_. vNext =========================== -* `#660`_: Userspace connection migration. This should be an invisible change, but represents a significant refactor of how connections function. -Instead of a weird blend of DatabaseConnection models and underlying Django models (which were the original Explorer connections), -this migrates all connections to DatabaseConnection models and implements proper foreign keys to them on the Query and QueryLog models. -A data migration creates new DatabaseConnection models based on the configured settings.EXPLORER_CONNECTIONS. -Going forward, admins can create new Django-backed DatabaseConnection models by registering the connection in EXPLORER_CONNECTIONS, and then creating a -DatabaseConnection model using the Django admin or the user-facing /connections/new/ form, and entering the Django DB alias and setting the connection type to "Django Connection" - +* Keyboard shortcut for formatting the SQL in the editor. + + - Cmd+Shift+F (Windows: Ctrl+Shift+F) + - The format button has been moved tobe a small icon towards the bottom-right of the SQL editor. + +* `#664`_: Improvements to the AI SQL Assistant: + + - Table Annotations: Write persistent table annotations with descriptive information that will get injected into the + prompt for the assistant. For example, if a table is commonly joined to another table through a non-obvious foreign + key, you can tell the assistant about it in plain english, as an annotation to that table. Every time that table is + deemed 'relevant' to an assistant request, that annotation will be included alongside the schema and sample data. + - Few-Shot Examples: Using the small checkbox on the bottom-right of any saved queries, you can designate certain + queries as 'few shot examples". When making an assistant request, any designated few-shot examples that reference + the same tables as your assistant request will get included as 'reference sql' in the prompt for the LLM. + - Autocomplete / multiselect when selecting tables info to send to the SQL Assistant. Much easier and more keyboard + focused. + - Relevant tables are added client-side visually, in real time, based on what's in the SQL editor. The dependency on + sql_metadata is therefore removed, as server-side SQL parsing is no longer necessary + - Improved system prompt that emphasizes the particular SQL dialect being used. + - Addresses issue #657. + +* `#660`_: Userspace connection migration. + + - This should be an invisible change, but represents a significant refactor of how connections function. Instead of a + weird blend of DatabaseConnection models and underlying Django models (which were the original Explorer + connections), this migrates all connections to DatabaseConnection models and implements proper foreign keys to them + on the Query and QueryLog models. A data migration creates new DatabaseConnection models based on the configured + settings.EXPLORER_CONNECTIONS. Going forward, admins can create new Django-backed DatabaseConnection models by + registering the connection in EXPLORER_CONNECTIONS, and then creating a DatabaseConnection model using the Django + admin or the user-facing /connections/new/ form, and entering the Django DB alias and setting the connection type + to "Django Connection". + - The Query.connection and QueryLog.connection fields are deprecated and will be removed in a future release. They + are kept around in this release in case there is an unforeseen issue with the migration. Preserving the fields for + now ensures there is no data loss in the event that a rollback to an earlier version is required. `5.2.0`_ (2024-08-19) =========================== diff --git a/docs/features.rst b/docs/features.rst index a93e9b39..c9dd8f42 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -5,7 +5,19 @@ SQL Assistant ------------- - Built in integration with OpenAI (or the LLM of your choosing) to quickly get help with your query, with relevant schema - automatically injected into the prompt. Simple, effective. + automatically injected into the prompt. +- The assistant tries hard to get relevant context into the prompt to the LLM, alongside your explicit request. You + can choose tables to include explicitly (and any tables you are reference in your SQL you will see get included as + well). When a table is "included", the prompt will include the schema of the table, 3 sample rows, any Table + Annotations you have added, and any designated "few shot examples". More on each of those below. +- Table Annotations: Write persistent table annotations with descriptive information that will get injected into the + prompt for the assistant. For example, if a table is commonly joined to another table through a non-obvious foreign + key, you can tell the assistant about it in plain english, as an annotation to that table. Every time that table is + deemed 'relevant' to an assistant request, that annotation will be included alongside the schema and sample data. +- Few-shot examples: Using the small checkbox on the bottom-right of any saved query, you can designate queries as + "Assistant Examples". When making an assistant request, the 'included tables' are intersected with tables referenced + by designated Example queries, and those queries are injected into the prompt, and the LLM is told that that these + are good reference queries. Database Support ---------------- @@ -222,8 +234,7 @@ Power tips view. - Command+Enter and Ctrl+Enter will execute a query when typing in the SQL editor area. -- Hit the "Format" button to format and clean up your SQL (this is - non-validating -- just formatting). +- Cmd+Shift+F (Windows: Ctrl+Shift+F) to format the SQL in the editor. - Use the Query Logs feature to share one-time queries that aren't worth creating a persistent query for. Just run your SQL in the playground, then navigate to ``/logs`` and share the link diff --git a/explorer/admin.py b/explorer/admin.py index 0cad25ac..219c87ce 100644 --- a/explorer/admin.py +++ b/explorer/admin.py @@ -7,7 +7,7 @@ @admin.register(Query) class QueryAdmin(admin.ModelAdmin): - list_display = ("title", "description", "created_by_user",) + list_display = ("title", "description", "created_by_user", "few_shot") list_filter = ("title",) raw_id_fields = ("created_by_user",) actions = [generate_report_action()] diff --git a/explorer/assistant/forms.py b/explorer/assistant/forms.py new file mode 100644 index 00000000..f4c69780 --- /dev/null +++ b/explorer/assistant/forms.py @@ -0,0 +1,42 @@ +from django import forms +from explorer.assistant.models import TableDescription +from explorer.ee.db_connections.utils import default_db_connection + + +class TableDescriptionForm(forms.ModelForm): + class Meta: + model = TableDescription + fields = "__all__" + widgets = { + "database_connection": forms.Select(attrs={"class": "form-select"}), + "description": forms.Textarea(attrs={"class": "form-control", "rows": 3}), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self.instance.pk: # Check if this is a new instance + # Set the default value for database_connection + self.fields["database_connection"].initial = default_db_connection() + + if self.instance and self.instance.table_name: + choices = [(self.instance.table_name, self.instance.table_name)] + else: + choices = [] + + f = forms.ChoiceField( + choices=choices, + widget=forms.Select(attrs={"class": "form-select"}) + ) + + # We don't actually care about validating the 'choices' that the ChoiceField does by default. + # Really we are just using that field type in order to get a valid pre-populated Select widget on the client + # But also it can't be blank! + def valid_value_new(v): + return bool(v) + + f.valid_value = valid_value_new + + self.fields["table_name"] = f + + if self.instance and self.instance.table_name: + self.fields["table_name"].initial = self.instance.table_name diff --git a/explorer/assistant/models.py b/explorer/assistant/models.py index 21c6db6d..87d7acff 100644 --- a/explorer/assistant/models.py +++ b/explorer/assistant/models.py @@ -1,5 +1,6 @@ from django.db import models from django.conf import settings +from explorer.ee.db_connections.models import DatabaseConnection class PromptLog(models.Model): @@ -19,3 +20,17 @@ class Meta: duration = models.FloatField(blank=True, null=True) # seconds model = models.CharField(blank=True, max_length=128, default="") error = models.TextField(blank=True, null=True) + + +class TableDescription(models.Model): + + class Meta: + app_label = "explorer" + unique_together = ("database_connection", "table_name") + + database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.CASCADE) + table_name = models.CharField(max_length=512) + description = models.TextField() + + def __str__(self): + return f"{self.database_connection.alias} - {self.table_name}" diff --git a/explorer/assistant/urls.py b/explorer/assistant/urls.py new file mode 100644 index 00000000..ece81fbd --- /dev/null +++ b/explorer/assistant/urls.py @@ -0,0 +1,14 @@ +from django.urls import path +from explorer.assistant.views import (TableDescriptionListView, + TableDescriptionCreateView, + TableDescriptionUpdateView, + TableDescriptionDeleteView, + AssistantHelpView) + +assistant_urls = [ + path("assistant/", AssistantHelpView.as_view(), name="assistant"), + path("table-descriptions/", TableDescriptionListView.as_view(), name="table_description_list"), + path("table-descriptions/new/", TableDescriptionCreateView.as_view(), name="table_description_create"), + path("table-descriptions//update/", TableDescriptionUpdateView.as_view(), name="table_description_update"), + path("table-descriptions//delete/", TableDescriptionDeleteView.as_view(), name="table_description_delete"), +] diff --git a/explorer/assistant/utils.py b/explorer/assistant/utils.py index 3dd3cdfb..0a962406 100644 --- a/explorer/assistant/utils.py +++ b/explorer/assistant/utils.py @@ -1,12 +1,16 @@ +from dataclasses import dataclass from explorer import app_settings from explorer.schema import schema_info -from explorer.models import ExplorerValue +from explorer.models import ExplorerValue, Query from django.db.utils import OperationalError +from django.db.models.functions import Lower +from django.db.models import Q +from explorer.assistant.models import TableDescription OPENAI_MODEL = app_settings.EXPLORER_ASSISTANT_MODEL["name"] -ROW_SAMPLE_SIZE = 2 -MAX_FIELD_SAMPLE_SIZE = 500 # characters +ROW_SAMPLE_SIZE = 3 +MAX_FIELD_SAMPLE_SIZE = 200 # characters def openai_client(): @@ -34,25 +38,17 @@ def extract_response(r): return r[-1].content -def tables_from_schema_info(db_connection, table_names): +def table_schema(db_connection, table_name): schema = schema_info(db_connection) - return [table for table in schema if table[0] in set(table_names)] - - -def sample_rows_from_tables(connection, table_names): - ret = "" - for table_name in table_names: - ret += f"SAMPLE FROM TABLE {table_name}:\n" - ret += format_rows_from_table( - sample_rows_from_table(connection, table_name) - ) + "\n\n" - return ret + s = [table for table in schema if table[0] == table_name] + if len(s): + return s[0][1] def sample_rows_from_table(connection, table_name): """ Fetches a sample of rows from the specified table and ensures that any field values - exceeding 500 characters (or bytes) are truncated. This is useful for handling fields + exceeding 200 characters (or bytes) are truncated. This is useful for handling fields like "description" that might contain very long strings of text or binary data. Truncating these fields prevents issues with displaying or processing overly large values. An ellipsis ("...") is appended to indicate that the data has been truncated. @@ -76,8 +72,8 @@ def sample_rows_from_table(connection, table_name): new_val = field if isinstance(field, str) and len(field) > MAX_FIELD_SAMPLE_SIZE: new_val = field[:MAX_FIELD_SAMPLE_SIZE] + "..." # Truncate and add ellipsis - elif isinstance(field, (bytes, bytearray)) and len(field) > MAX_FIELD_SAMPLE_SIZE: - new_val = field[:MAX_FIELD_SAMPLE_SIZE] + b"..." # Truncate binary data + elif isinstance(field, (bytes, bytearray)): + new_val = "" processed_row.append(new_val) ret.append(processed_row) @@ -87,66 +83,97 @@ def sample_rows_from_table(connection, table_name): def format_rows_from_table(rows): - column_headers = list(rows[0]) - ret = " | ".join(column_headers) + "\n" + "-" * 50 + "\n" - for row in rows[1:]: - row_str = " | ".join(str(item) for item in row) - ret += row_str + "\n" - return ret - - -def get_table_names_from_query(sql): - from sql_metadata import Parser - if sql: - try: - parsed = Parser(sql) - return parsed.tables - except ValueError: - return [] - return [] - - -def num_tokens_from_string(string: str) -> int: - """Returns the number of tokens in a text string.""" - import tiktoken - try: - encoding = tiktoken.encoding_for_model(OPENAI_MODEL) - except KeyError: - encoding = tiktoken.get_encoding("cl100k_base") - num_tokens = len(encoding.encode(string)) - return num_tokens + """ Given an array of rows (a list of lists), returns e.g. +AlbumId | Title | ArtistId +1 | For Those About To Rock We Salute You | 1 +2 | Let It Rip | 2 +3 | Restless and Wild | 2 -def fits_in_window(string: str) -> bool: - # Ratchet down by 5% to account for other boilerplate and system prompt - # TODO make this better by actually looking at the token count of the system prompt - return num_tokens_from_string(string) < (app_settings.EXPLORER_ASSISTANT_MODEL["max_tokens"] * 0.95) + """ + return "\n".join([" | ".join([str(item) for item in row]) for row in rows]) -def build_prompt(db_connection, assistant_request, included_tables, query_error=None, sql=None): - user_prompt = "" - djc = db_connection.as_django_connection() - user_prompt += f"## Database Vendor / SQL Flavor is {djc.vendor}\n\n" +def build_system_prompt(flavor): + bsp = ExplorerValue.objects.get_item(ExplorerValue.ASSISTANT_SYSTEM_PROMPT).value + bsp += f"\nYou are an expert at writing SQL, specifically for {flavor}, and account for the nuances of this dialect of SQL. You always respond with valid {flavor} SQL." # noqa + return bsp + + +def get_relevant_annotation(db_connection, t): + return TableDescription.objects.annotate( + table_name_lower=Lower("table_name") + ).filter( + database_connection=db_connection, + table_name_lower=t.lower() + ).first() + + +def get_relevant_few_shots(db_connection, included_tables): + included_tables_lower = [t.lower() for t in included_tables] + + query_conditions = Q() + for table in included_tables_lower: + query_conditions |= Q(sql__icontains=table) - if query_error: - user_prompt += f"## Query Error ##\n\n{query_error}\n\n" + return Query.objects.annotate( + sql_lower=Lower("sql") + ).filter( + database_connection=db_connection, + few_shot=True + ).filter(query_conditions) - if sql: - user_prompt += f"## Existing SQL ##\n\n{sql}\n\n" - results_sample = sample_rows_from_tables(djc, - included_tables) - if fits_in_window(user_prompt + results_sample): - user_prompt += f"## Table Structure with Sampled Data ##\n\n{results_sample}\n\n" - else: # If it's too large with sampling, then provide *just* the structure - table_struct = tables_from_schema_info(db_connection, - included_tables) - user_prompt += f"## Table Structure ##\n\n{table_struct}\n\n" +def get_few_shot_chunk(db_connection, included_tables): + included_tables = [t.lower() for t in included_tables] + few_shot_examples = get_relevant_few_shots(db_connection, included_tables) + if few_shot_examples: + return "## Relevant example queries, written by expert SQL analysts ##\n" + "\n\n".join( + [f"Description: {fs.title} - {fs.description}\nSQL:\n{fs.sql}" + for fs in few_shot_examples.all()] + ) + + +@dataclass +class TablePromptData: + name: str + schema: list + sample: list + annotation: TableDescription + + def render(self): + fmt_schema = "\n".join([str(field) for field in self.schema]) + ret = f"""## Information for Table '{self.name}' ## + +Schema:\n{fmt_schema} + +Sample rows:\n{format_rows_from_table(self.sample)}""" + if self.annotation: + ret += f"\nUsage Notes:\n{self.annotation.description}" + return ret + + +def build_prompt(db_connection, assistant_request, included_tables, query_error=None, sql=None): + included_tables = [t.lower() for t in included_tables] + + error_chunk = f"## Query Error ##\n{query_error}" if query_error else None + sql_chunk = f"## Existing User-Written SQL ##\n{sql}" if sql else None + request_chunk = f"## User's Request to Assistant ##\n{assistant_request}" + table_chunks = [ + TablePromptData( + name=t, + schema=table_schema(db_connection, t), + sample=sample_rows_from_table(db_connection.as_django_connection(), t), + annotation=get_relevant_annotation(db_connection, t) + ).render() + for t in included_tables + ] + few_shot_chunk = get_few_shot_chunk(db_connection, included_tables) - user_prompt += f"## User's Request to Assistant ##\n\n{assistant_request}\n\n" + chunks = [error_chunk, sql_chunk, *table_chunks, few_shot_chunk, request_chunk] prompt = { - "system": ExplorerValue.objects.get_item(ExplorerValue.ASSISTANT_SYSTEM_PROMPT).value, - "user": user_prompt + "system": build_system_prompt(db_connection.as_django_connection().vendor), + "user": "\n\n".join([c for c in chunks if c]), } return prompt diff --git a/explorer/assistant/views.py b/explorer/assistant/views.py index f5d6f2aa..c6507e2c 100644 --- a/explorer/assistant/views.py +++ b/explorer/assistant/views.py @@ -1,14 +1,21 @@ from django.http import JsonResponse from django.views import View from django.utils import timezone +from django.views.generic import ListView, CreateView, UpdateView, DeleteView +from django.urls import reverse_lazy + +from .forms import TableDescriptionForm +from .models import TableDescription + import json +from explorer.views.auth import PermissionRequiredMixin +from explorer.views.mixins import ExplorerContextMixin from explorer.telemetry import Stat, StatNames from explorer.ee.db_connections.models import DatabaseConnection from explorer.assistant.models import PromptLog from explorer.assistant.utils import ( do_req, extract_response, - get_table_names_from_query, build_prompt ) @@ -16,8 +23,7 @@ def run_assistant(request_data, user): sql = request_data.get("sql") - extra_tables = request_data.get("selected_tables", []) - included_tables = get_table_names_from_query(sql) + extra_tables + included_tables = request_data.get("selected_tables", []) connection_id = request_data.get("connection_id") try: @@ -47,7 +53,6 @@ def run_assistant(request_data, user): pl.save() Stat(StatNames.ASSISTANT_RUN, { "included_table_count": len(included_tables), - "extra_table_count": len(extra_tables), "has_sql": bool(sql), "duration": pl.duration, }).track() @@ -67,3 +72,33 @@ def post(self, request, *args, **kwargs): return JsonResponse(response_data) except json.JSONDecodeError: return JsonResponse({"status": "error", "message": "Invalid JSON"}, status=400) + + +class TableDescriptionListView(PermissionRequiredMixin, ExplorerContextMixin, ListView): + model = TableDescription + permission_required = "view_permission" + template_name = "assistant/table_description_list.html" + context_object_name = "table_descriptions" + + +class TableDescriptionCreateView(PermissionRequiredMixin, ExplorerContextMixin, CreateView): + model = TableDescription + permission_required = "change_permission" + template_name = "assistant/table_description_form.html" + success_url = reverse_lazy("table_description_list") + form_class = TableDescriptionForm + + +class TableDescriptionUpdateView(PermissionRequiredMixin, ExplorerContextMixin, UpdateView): + model = TableDescription + permission_required = "change_permission" + template_name = "assistant/table_description_form.html" + success_url = reverse_lazy("table_description_list") + form_class = TableDescriptionForm + + +class TableDescriptionDeleteView(PermissionRequiredMixin, ExplorerContextMixin, DeleteView): + model = TableDescription + permission_required = "change_permission" + template_name = "assistant/table_description_confirm_delete.html" + success_url = reverse_lazy("table_description_list") diff --git a/explorer/ee/urls.py b/explorer/ee/urls.py index 0359d628..e1bfdc92 100644 --- a/explorer/ee/urls.py +++ b/explorer/ee/urls.py @@ -20,12 +20,7 @@ path("connections/create_upload/", DatabaseConnectionUploadCreateView.as_view(), name="explorer_upload_create"), path("connections//edit/", DatabaseConnectionUpdateView.as_view(), name="explorer_connection_update"), path("connections//delete/", DatabaseConnectionDeleteView.as_view(), name="explorer_connection_delete"), - # There are two URLs here because the form can call validate from /connections/new/ or from /connections//edit/ - # which have different relative paths. It's easier to just provide both of these URLs rather than deal with this - # client-side. path("connections/validate/", DatabaseConnectionValidateView.as_view(), name="explorer_connection_validate"), - path("connections//validate/", DatabaseConnectionValidateView.as_view(), - name="explorer_connection_validate_with_pk"), path("connections//refresh/", DatabaseConnectionRefreshView.as_view(), name="explorer_connection_refresh") ] diff --git a/explorer/forms.py b/explorer/forms.py index fb6511c2..e025b558 100644 --- a/explorer/forms.py +++ b/explorer/forms.py @@ -34,6 +34,7 @@ class QueryForm(ModelForm): sql = SqlField() snapshot = BooleanField(widget=CheckboxInput, required=False) + few_shot = BooleanField(widget=CheckboxInput, required=False) database_connection = CharField(widget=Select, required=False) def __init__(self, *args, **kwargs): @@ -77,4 +78,4 @@ def connections(self): class Meta: model = Query - fields = ["title", "sql", "description", "snapshot", "database_connection"] + fields = ["title", "sql", "description", "snapshot", "database_connection", "few_shot"] diff --git a/explorer/migrations/0026_tabledescription.py b/explorer/migrations/0026_tabledescription.py new file mode 100644 index 00000000..782ed1bc --- /dev/null +++ b/explorer/migrations/0026_tabledescription.py @@ -0,0 +1,26 @@ +# Generated by Django 5.0.4 on 2024-08-22 01:24 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('explorer', '0025_alter_query_database_connection_alter_querylog_database_connection'), + ] + + operations = [ + migrations.CreateModel( + name='TableDescription', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('table_name', models.CharField(max_length=512)), + ('description', models.TextField()), + ('database_connection', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='explorer.databaseconnection')), + ], + options={ + 'unique_together': {('database_connection', 'table_name')}, + }, + ), + ] diff --git a/explorer/migrations/0027_query_few_shot.py b/explorer/migrations/0027_query_few_shot.py new file mode 100644 index 00000000..18e901fb --- /dev/null +++ b/explorer/migrations/0027_query_few_shot.py @@ -0,0 +1,18 @@ +# Generated by Django 5.0.4 on 2024-08-25 21:26 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('explorer', '0026_tabledescription'), + ] + + operations = [ + migrations.AddField( + model_name='query', + name='few_shot', + field=models.BooleanField(default=False, help_text='Will be included as a good example of SQL in assistant queries that use relevant tables'), + ), + ] diff --git a/explorer/models.py b/explorer/models.py index 9fe8ef11..67939df8 100644 --- a/explorer/models.py +++ b/explorer/models.py @@ -9,8 +9,6 @@ from django.utils.translation import gettext_lazy as _ from explorer import app_settings -# import the models so that the migration tooling knows the assistant models are part of the explorer app -from explorer.assistant import models as assistant_models # noqa from explorer.telemetry import Stat, StatNames from explorer.utils import ( extract_params, get_params_for_url, get_s3_bucket, passes_blacklist, s3_url, @@ -21,7 +19,7 @@ # Issue #618. All models must be imported so that Django understands how to manage migrations for the app from explorer.ee.db_connections.models import DatabaseConnection # noqa -from explorer.assistant.models import PromptLog # noqa +from explorer.assistant.models import PromptLog, TableDescription # noqa MSG_FAILED_BLACKLIST = "Query failed the SQL blacklist: %s" @@ -58,6 +56,8 @@ class Query(models.Model): ) ) database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, null=True) + few_shot = models.BooleanField(default=False, help_text=_( + "Will be included as a good example of SQL in assistant queries that use relevant tables")) def __init__(self, *args, **kwargs): self.params = kwargs.get("params") diff --git a/explorer/src/js/assistant.js b/explorer/src/js/assistant.js index 9e7c4690..b6e52026 100644 --- a/explorer/src/js/assistant.js +++ b/explorer/src/js/assistant.js @@ -1,9 +1,9 @@ import {getCsrfToken} from "./csrf"; import { marked } from "marked"; import DOMPurify from "dompurify"; -import * as bootstrap from 'bootstrap'; -import List from "list.js"; +import * as bootstrap from "bootstrap"; import { SchemaSvc, getConnElement } from "./schemaService" +import Choices from "choices.js" function getErrorMessage() { const errorElement = document.querySelector('.alert-danger.db-error'); @@ -11,52 +11,59 @@ function getErrorMessage() { } function setupTableList() { + + if(window.assistantChoices) { + window.assistantChoices.destroy(); + } + SchemaSvc.get().then(schema => { const keys = Object.keys(schema); + const selectElement = document.createElement('select'); + selectElement.className = 'js-choice'; + selectElement.toggleAttribute('multiple'); + selectElement.toggleAttribute('data-trigger'); + + keys.forEach((key) => { + const option = document.createElement('option'); + option.value = key; + option.textContent = key; + selectElement.appendChild(option); + }); + const tableList = document.getElementById('table-list'); tableList.innerHTML = ''; - - keys.forEach((key, index) => { - const div = document.createElement('div'); - div.className = 'form-check'; - - const input = document.createElement('input'); - input.className = 'form-check-input table-checkbox'; - input.type = 'checkbox'; - input.value = key; - input.id = 'flexCheckDefault' + index; - - const label = document.createElement('label'); - label.className = 'form-check-label'; - label.setAttribute('for', input.id); - label.textContent = key; - - div.appendChild(input); - div.appendChild(label); - tableList.appendChild(div); + tableList.appendChild(selectElement); + + const choices = new Choices('.js-choice', { + removeItemButton: true, + searchEnabled: true, + shouldSort: false, + placeholder: true, + placeholderValue: 'Relevant tables', + position: 'bottom' }); - let options = { - valueNames: ['form-check-label'], - }; - - new List('additional_table_container', options); + // TODO - nasty. Should be refactored. Used by submitAssistantAsk to get relevant tables. + window.assistantChoices = choices; const selectAllButton = document.getElementById('select_all_button'); - const checkboxes = document.querySelectorAll('.table-checkbox'); - - let selectState = 'all'; - - selectAllButton.innerHTML = 'Select All'; - selectAllButton.addEventListener('click', (e) => { e.preventDefault(); - const isSelectingAll = selectState === 'all'; - checkboxes.forEach((checkbox) => { - checkbox.checked = isSelectingAll; + choices.setChoiceByValue(keys); + }); + + const deselectAllButton = document.getElementById('deselect_all_button'); + deselectAllButton.addEventListener('click', (e) => { + e.preventDefault(); + keys.forEach(k => { + choices.removeActiveItemsByValue(k); }); - selectState = isSelectingAll ? 'none' : 'all'; - selectAllButton.innerHTML = isSelectingAll ? 'Deselect All' : 'Select All'; + }); + + selectRelevantTables(choices, keys); + + document.addEventListener('docChanged', (e) => { + selectRelevantTables(choices, keys); }); }) .catch(error => { @@ -64,6 +71,14 @@ function setupTableList() { }); } +function selectRelevantTables(choices, keys) { + const textContent = window.editor.state.doc.toString(); + const textWords = new Set(textContent.split(/\s+/)); + const hasKeys = keys.filter(key => textWords.has(key)); + choices.setChoiceByValue(hasKeys); +} + + export function setUpAssistant(expand = false) { getConnElement().addEventListener('change', setupTableList); @@ -71,13 +86,13 @@ export function setUpAssistant(expand = false) { const error = getErrorMessage(); - if(expand || error) { + if (expand || error) { const myCollapseElement = document.getElementById('assistant_collapse'); const bsCollapse = new bootstrap.Collapse(myCollapseElement, { - toggle: false + toggle: false }); bsCollapse.show(); - if(error) { + if (error) { document.getElementById('id_error_help_message').classList.remove('d-none'); } } @@ -85,7 +100,7 @@ export function setUpAssistant(expand = false) { const tooltipTriggerList = document.querySelectorAll('[data-bs-toggle="tooltip"]'); [...tooltipTriggerList].map(tooltipTriggerEl => new bootstrap.Tooltip(tooltipTriggerEl)); - document.getElementById('id_assistant_input').addEventListener('keydown', function(event) { + document.getElementById('id_assistant_input').addEventListener('keydown', function (event) { if ((event.ctrlKey || event.metaKey) && (event.key === 'Enter')) { event.preventDefault(); submitAssistantAsk(); @@ -97,15 +112,11 @@ export function setUpAssistant(expand = false) { function submitAssistantAsk() { - const selectedTables = Array.from( - document.querySelectorAll('.table-checkbox:checked') - ).map(cb => cb.value); - const data = { sql: window.editor?.state.doc.toString() ?? null, connection_id: document.getElementById("id_database_connection")?.value ?? null, assistant_request: document.getElementById("id_assistant_input")?.value ?? null, - selected_tables: selectedTables, + selected_tables: assistantChoices.getValue(true), db_error: getErrorMessage() }; @@ -113,7 +124,7 @@ function submitAssistantAsk() { document.getElementById("response_block").classList.remove('d-none'); document.getElementById("assistant_spinner").classList.remove('d-none'); - fetch('../assistant/', { + fetch(`${window.baseUrlPath}assistant/`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/explorer/src/js/codemirror-config.js b/explorer/src/js/codemirror-config.js index a4539e49..08586fe1 100644 --- a/explorer/src/js/codemirror-config.js +++ b/explorer/src/js/codemirror-config.js @@ -14,9 +14,14 @@ import { Prec } from "@codemirror/state"; import {sql} from "@codemirror/lang-sql"; import { SchemaSvc } from "./schemaService" +let debounceTimeout; + let updateListenerExtension = EditorView.updateListener.of((update) => { if (update.docChanged) { - document.dispatchEvent(new CustomEvent('docChanged', {})); + clearTimeout(debounceTimeout); + debounceTimeout = setTimeout(() => { + document.dispatchEvent(new CustomEvent('docChanged', {})); + }, 500); } }); @@ -34,12 +39,21 @@ const hideTooltipOnEsc = EditorView.domEventHandlers({ } }); -function displaySchemaTooltip(editor, content) { +function displaySchemaTooltip(content) { let tooltip = document.getElementById('schema_tooltip'); if (tooltip) { tooltip.classList.remove('d-none'); tooltip.classList.add('d-block'); - tooltip.textContent = content; + + // Clear existing content + tooltip.textContent = ''; + + content.forEach(item => { + let column = document.createElement('span'); + column.textContent = item; + column.classList.add('mx-1') + tooltip.appendChild(column); + }); } } @@ -53,11 +67,11 @@ function fetchAndShowSchema(view) { SchemaSvc.get().then(schema => { let formattedSchema; if (schema.hasOwnProperty(tableName)) { - formattedSchema = JSON.stringify(schema[tableName], null, 2); + displaySchemaTooltip(schema[tableName]); } else { - formattedSchema = `Table '${tableName}' not found in schema for connection`; + const errorMsg = [`Table '${tableName}' not found in schema for connection`]; + displaySchemaTooltip(errorMsg); } - displaySchemaTooltip(view, formattedSchema); }); } return true; @@ -99,6 +113,27 @@ const submitKeymapArr = [ } ] + +const formatEventFromCM = new CustomEvent('formatEventFromCM', {}); +const formatKeymap = [ + { + key: "Ctrl-F", + mac: "Cmd-F", + run: () => { + document.dispatchEvent(formatEventFromCM); + return true; + } + }, + { + key: "Ctrl-F", + mac: "Cmd-F", + run: () => { + document.dispatchEvent(formatEventFromCM); + return true; + } + } +] + const submitKeymap = Prec.highest( keymap.of( submitKeymapArr @@ -136,6 +171,7 @@ export const explorerSetup = (() => [ ...completionKeymap, ...lintKeymap, ...autocompleteKeymap, - ...schemaKeymap + ...schemaKeymap, + ...formatKeymap ]) ])() diff --git a/explorer/src/js/explorer.js b/explorer/src/js/explorer.js index ad113de1..02968d37 100644 --- a/explorer/src/js/explorer.js +++ b/explorer/src/js/explorer.js @@ -89,6 +89,10 @@ export class ExplorerEditor { this.$submit.click(); }); + document.addEventListener('formatEventFromCM', (e) => { + this.formatSql(); + }); + document.addEventListener('docChanged', (e) => { this.docChanged = true; }); @@ -157,7 +161,7 @@ export class ExplorerEditor { formData.append('sql', sqlText); // Append the SQL text to the form data // Make the fetch call - fetch("../format/", { + fetch(`${window.baseUrlPath}format/`, { method: "POST", headers: { // 'Content-Type': 'application/x-www-form-urlencoded', // Not needed when using FormData, as the browser sets it along with the boundary diff --git a/explorer/src/js/main.js b/explorer/src/js/main.js index 4d698497..8676753a 100644 --- a/explorer/src/js/main.js +++ b/explorer/src/js/main.js @@ -18,11 +18,14 @@ const route_initializers = { explorer_schema: () => import('./schema').then(({setupSchema}) => setupSchema()), explorer_upload_create: () => import('./uploads').then(({setupUploads}) => setupUploads()), explorer_connection_update: () => import('./uploads').then(({setupUploads}) => setupUploads()), - explorer_connection_create: () => import('./uploads').then(({setupUploads}) => setupUploads()) + explorer_connection_create: () => import('./uploads').then(({setupUploads}) => setupUploads()), + table_description_create: () => import('./tableDescription').then(({setupTableDescription}) => setupTableDescription()), + table_description_update: () => import('./tableDescription').then(({setupTableDescription}) => setupTableDescription()), }; document.addEventListener('DOMContentLoaded', function() { const clientRoute = document.getElementById('clientRoute').value; + window.baseUrlPath = document.getElementById('baseUrlPath').value; if (route_initializers.hasOwnProperty(clientRoute)) { route_initializers[clientRoute](); } diff --git a/explorer/src/js/query-list.js b/explorer/src/js/query-list.js index 0646dcb8..2ea8045a 100644 --- a/explorer/src/js/query-list.js +++ b/explorer/src/js/query-list.js @@ -60,7 +60,7 @@ function setUpEmailCsv() { } let handleEmailCsvSubmit = function (e) { let email = document.querySelector('#emailCsvInput').value; - let url = '/' + curQueryEmailId + '/email_csv?email=' + email; + let url =`${window.baseUrlPath}${curQueryEmailId}/email_csv?email=${email}`; if (isValidEmail(email)) { fetch(url, { method: 'POST', diff --git a/explorer/src/js/schemaService.js b/explorer/src/js/schemaService.js index 84e9e577..afb4e4f8 100644 --- a/explorer/src/js/schemaService.js +++ b/explorer/src/js/schemaService.js @@ -9,7 +9,7 @@ const fetchSchema = async () => { } try { - const response = await fetch(`../schema.json/${conn}`); + const response = await fetch(`${window.baseUrlPath}schema.json/${conn}`); if (!response.ok) { throw new Error(`HTTP error! Status: ${response.status}`); } diff --git a/explorer/src/js/tableDescription.js b/explorer/src/js/tableDescription.js new file mode 100644 index 00000000..0b59ffe9 --- /dev/null +++ b/explorer/src/js/tableDescription.js @@ -0,0 +1,41 @@ +import {getConnElement, SchemaSvc} from "./schemaService" +import Choices from "choices.js" + + +function populateTableList() { + + if (window.tableChoices) { + window.tableChoices.destroy(); + document.getElementById('id_table_name').innerHTML = ''; + } + + SchemaSvc.get().then(schema => { + + const tables = Object.keys(schema); + const selectElement = document.getElementById('id_table_name'); + selectElement.toggleAttribute('data-trigger'); + + selectElement.appendChild(document.createElement('option')); + + tables.forEach((t) => { + const option = document.createElement('option'); + option.value = t; + option.textContent = t; + selectElement.appendChild(option); + }); + + window.tableChoices = new Choices('#id_table_name', { + searchEnabled: true, + shouldSort: true, + placeholder: true, + placeholderValue: 'Select table', + position: 'bottom' + }); + + }); +} + +export function setupTableDescription() { + getConnElement().addEventListener('change', populateTableList); + populateTableList(); +} diff --git a/explorer/src/js/uploads.js b/explorer/src/js/uploads.js index 6b653a04..508eb9a5 100644 --- a/explorer/src/js/uploads.js +++ b/explorer/src/js/uploads.js @@ -52,7 +52,7 @@ export function setupUploads() { } let xhr = new XMLHttpRequest(); - xhr.open('POST', '../upload/', true); + xhr.open('POST', `${window.baseUrlPath}connections/upload/`, true); xhr.setRequestHeader('X-CSRFToken', getCsrfToken()); xhr.upload.onprogress = function(event) { @@ -91,7 +91,7 @@ export function setupUploads() { let form = document.getElementById("db-connection-form"); let formData = new FormData(form); - fetch("../validate/", { + fetch(`${window.baseUrlPath}validate/`, { method: "POST", body: formData, headers: { diff --git a/explorer/src/scss/assistant.scss b/explorer/src/scss/assistant.scss index fbb2c625..92fdedd2 100644 --- a/explorer/src/scss/assistant.scss +++ b/explorer/src/scss/assistant.scss @@ -14,13 +14,7 @@ cursor: pointer; } -#additional_table_container { - overflow-y: auto; - max-height: 10rem; -} - #assistant_input_parent { max-height: 120px; overflow: hidden; } - diff --git a/explorer/src/scss/styles.scss b/explorer/src/scss/styles.scss index cece367d..9558b24f 100644 --- a/explorer/src/scss/styles.scss +++ b/explorer/src/scss/styles.scss @@ -10,3 +10,4 @@ $bootstrap-icons-font-dir: "../../../node_modules/bootstrap-icons/font/fonts"; @import "assistant"; @import "pivot.css"; +@import "choices.js/public/assets/styles/choices.css"; diff --git a/explorer/templates/assistant/table_description_confirm_delete.html b/explorer/templates/assistant/table_description_confirm_delete.html new file mode 100644 index 00000000..87080208 --- /dev/null +++ b/explorer/templates/assistant/table_description_confirm_delete.html @@ -0,0 +1,14 @@ +{% extends "explorer/base.html" %} + +{% block sql_explorer_content %} +
+

Confirm Delete

+

Are you sure you want to delete the table description for "{{ object.table_name }}" in {{ object.connection.alias }}?

+
+ {% csrf_token %} + + Cancel +
+
+{% endblock %} + diff --git a/explorer/templates/assistant/table_description_form.html b/explorer/templates/assistant/table_description_form.html new file mode 100644 index 00000000..e6d42118 --- /dev/null +++ b/explorer/templates/assistant/table_description_form.html @@ -0,0 +1,39 @@ +{% extends "explorer/base.html" %} + +{% block sql_explorer_content %} +
+

{% if form.instance.pk %}Edit{% else %}Create{% endif %} Table Description

+ {% if form.errors %} + {% for field in form %} + {% for error in field.errors %} +
+ {{ error|escape }} +
+ {% endfor %} + {% endfor %} + {% for error in form.non_field_errors %} +
+ {{ error|escape }} +
+ {% endfor %} +{% endif %} +
+ {% csrf_token %} +
+
+ {{ form.database_connection }} + +
+
+ {{ form.table_name }} +
+
+ {{ form.description }} + +
+
+ + Cancel +
+
+{% endblock %} diff --git a/explorer/templates/assistant/table_description_list.html b/explorer/templates/assistant/table_description_list.html new file mode 100644 index 00000000..1ae59f3d --- /dev/null +++ b/explorer/templates/assistant/table_description_list.html @@ -0,0 +1,38 @@ +{% extends "explorer/base.html" %} + +{% block sql_explorer_content %} +
+

Table Annotations

+

These will be injected into any AI assistant prompts that reference the annotated table.

+
+ Create New + + + + + + + + + + + {% for table_description in table_descriptions %} + + + + + + + {% empty %} + + + + {% endfor %} + +
ConnectionTable NameDescriptionActions
{{ table_description.database_connection }}{{ table_description.table_name }}{{ table_description.description|truncatewords:20 }} + + +
No table descriptions available.
+
+
+{% endblock %} diff --git a/explorer/templates/explorer/assistant.html b/explorer/templates/explorer/assistant.html index beb4deba..f3efd243 100644 --- a/explorer/templates/explorer/assistant.html +++ b/explorer/templates/explorer/assistant.html @@ -10,9 +10,9 @@
-
+
@@ -20,20 +20,23 @@
- - - (?) - - +
+
+ + +
+
+
+
-
+
diff --git a/explorer/templates/explorer/base.html b/explorer/templates/explorer/base.html index 5eb41780..32bac5ba 100644 --- a/explorer/templates/explorer/base.html +++ b/explorer/templates/explorer/base.html @@ -19,6 +19,7 @@ + {% if vite_dev_mode %}

@@ -63,6 +64,12 @@

This is easy to fix, I promise!

{% translate "Connections" %} + {% if assistant_enabled %} + + {% endif %} {% endif %} {% endif %}
{% if query and can_change and tasks_enabled %}{{ form.snapshot }} {% translate "Snapshot" %}{% endif %} + {% if query and can_change and assistant_enabled %}{{ form.few_shot }} {% translate "Assistant Example" %}{% endif %}
{% endblock %} diff --git a/explorer/tests/test_assistant.py b/explorer/tests/test_assistant.py index e9a9484a..fae7df30 100644 --- a/explorer/tests/test_assistant.py +++ b/explorer/tests/test_assistant.py @@ -9,7 +9,15 @@ from django.contrib.auth.models import User from django.db import OperationalError from explorer.ee.db_connections.utils import default_db_connection -from explorer.assistant.utils import sample_rows_from_table, ROW_SAMPLE_SIZE, build_prompt +from explorer.assistant.utils import ( + sample_rows_from_table, + ROW_SAMPLE_SIZE, + build_prompt, + get_relevant_few_shots, + get_relevant_annotation +) + +from explorer.assistant.models import TableDescription def conn(): @@ -31,23 +39,19 @@ def setUp(self): } @patch("explorer.assistant.utils.openai_client") - @patch("explorer.assistant.utils.num_tokens_from_string") - def test_do_modify_query(self, mocked_num_tokens, mocked_openai_client): + def test_do_modify_query(self, mocked_openai_client): from explorer.assistant.views import run_assistant # create.return_value should match: resp.choices[0].message mocked_openai_client.return_value.chat.completions.create.return_value = Mock( choices=[Mock(message=Mock(content="smart computer"))]) - mocked_num_tokens.return_value = 100 resp = run_assistant(self.request_data, None) self.assertEqual(resp, "smart computer") @patch("explorer.assistant.utils.openai_client") - @patch("explorer.assistant.utils.num_tokens_from_string") - def test_assistant_help(self, mocked_num_tokens, mocked_openai_client): + def test_assistant_help(self, mocked_openai_client): mocked_openai_client.return_value.chat.completions.create.return_value = Mock( choices=[Mock(message=Mock(content="smart computer"))]) - mocked_num_tokens.return_value = 100 resp = self.client.post(reverse("assistant"), data=json.dumps(self.request_data), content_type="application/json") @@ -57,38 +61,45 @@ def test_assistant_help(self, mocked_num_tokens, mocked_openai_client): @unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled") class TestBuildPrompt(TestCase): - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_vendor_only(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_vendor_only(self, mock_get_item): mock_get_item.return_value.value = "system prompt" + result = build_prompt(default_db_connection(), + "Help me with SQL", [], sql="SELECT * FROM table;") + self.assertIn("sqlite", result["system"]) - included_tables = [] + @patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data") + @patch("explorer.assistant.utils.table_schema", return_value=[]) + @patch("explorer.models.ExplorerValue.objects.get_item") + def test_build_prompt_with_sql_and_annotation(self, mock_get_item, mock_table_schema, mock_sample_rows): + mock_get_item.return_value.value = "system prompt" - result = build_prompt(default_db_connection(), "Help me with SQL", included_tables) - self.assertIn("## Database Vendor / SQL Flavor is sqlite", result["user"]) - self.assertIn("## User's Request to Assistant ##\n\nHelp me with SQL\n\n", result["user"]) - self.assertEqual(result["system"], "system prompt") + included_tables = ["foo"] + td = TableDescription(database_connection=default_db_connection(), table_name="foo", description="annotated") + td.save() - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) + result = build_prompt(default_db_connection(), + "Help me with SQL", included_tables, sql="SELECT * FROM table;") + self.assertIn("Usage Notes:\nannotated", result["user"]) + + @patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data") + @patch("explorer.assistant.utils.table_schema", return_value=[]) @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_sql(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_few_shot(self, mock_get_item, mock_table_schema, mock_sample_rows): mock_get_item.return_value.value = "system prompt" - included_tables = [] + included_tables = ["magic"] + SimpleQueryFactory(title="Few shot", description="the quick brown fox", sql="select 'magic value';", + few_shot=True) result = build_prompt(default_db_connection(), "Help me with SQL", included_tables, sql="SELECT * FROM table;") - self.assertIn("## Database Vendor / SQL Flavor is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\n\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\n\nHelp me with SQL\n\n", result["user"]) - self.assertEqual(result["system"], "system prompt") + self.assertIn("Relevant example queries", result["user"]) + self.assertIn("magic value", result["user"]) - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) + @patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data") @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_sample_rows): mock_get_item.return_value.value = "system prompt" included_tables = [] @@ -96,45 +107,22 @@ def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_fits_in_windo result = build_prompt(default_db_connection(), "Help me with SQL", included_tables, "Syntax error", "SELECT * FROM table;") - self.assertIn("## Database Vendor / SQL Flavor is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\n\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## Query Error ##\n\nSyntax error\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\n\nHelp me with SQL\n\n", result["user"]) - self.assertEqual(result["system"], "system prompt") - - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) - @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_extra_tables_fitting_window(self, mock_get_item, mock_fits_in_window, mock_sample_rows): - mock_get_item.return_value.value = "system prompt" - - included_tables = ["table1", "table2"] + self.assertIn("## Existing User-Written SQL ##\nSELECT * FROM table;", result["user"]) + self.assertIn("## Query Error ##\nSyntax error\n", result["user"]) + self.assertIn("## User's Request to Assistant ##\nHelp me with SQL", result["user"]) + self.assertIn("system prompt", result["system"]) - result = build_prompt(default_db_connection(), "Help me with SQL", - included_tables, sql="SELECT * FROM table;") - self.assertIn("## Database Vendor / SQL Flavor is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\n\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## Table Structure with Sampled Data ##\n\nsample data\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\n\nHelp me with SQL\n\n", result["user"]) - self.assertEqual(result["system"], "system prompt") - - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=False) - @patch("explorer.assistant.utils.tables_from_schema_info", return_value="table structure") @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_extra_tables_not_fitting_window(self, mock_get_item, mock_tables_from_schema_info, - mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_extra_tables_fitting_window(self, mock_get_item): mock_get_item.return_value.value = "system prompt" - included_tables = ["table1", "table2"] + included_tables = ["explorer_query"] + SimpleQueryFactory() result = build_prompt(default_db_connection(), "Help me with SQL", included_tables, sql="SELECT * FROM table;") - self.assertIn("## Database Vendor / SQL Flavor is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\n\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## Table Structure ##\n\ntable structure\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\n\nHelp me with SQL\n\n", result["user"]) - self.assertEqual(result["system"], "system prompt") + self.assertIn("## Information for Table 'explorer_query' ##", result["user"]) + self.assertIn("Sample rows:\nid | title", result["user"]) @unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled") @@ -161,10 +149,10 @@ def test_truncates_long_strings(self): header, row = ret self.assertEqual(header, ["col1", "col2"]) - self.assertEqual(row[0], "a" * 500 + "...") + self.assertEqual(row[0], "a" * 200 + "...") self.assertEqual(row[1], "short string") - def test_truncates_long_binary_data(self): + def test_binary_data(self): long_binary = b"a" * 600 # Mock database connection and cursor @@ -179,8 +167,8 @@ def test_truncates_long_binary_data(self): header, row = ret self.assertEqual(header, ["col1", "col2"]) - self.assertEqual(row[0], b"a" * 500 + b"...") - self.assertEqual(row[1], b"short binary") + self.assertEqual(row[0], "") + self.assertEqual(row[1], "") def test_handles_various_data_types(self): # Mock database connection and cursor @@ -217,24 +205,12 @@ def test_format_rows_from_table(self): ["val1", "val2"], ] ret = format_rows_from_table(d) - self.assertEqual(ret, "col1 | col2\n" + "-" * 50 + "\nval1 | val2\n") - - def test_parsing_tables_from_query(self): - from explorer.assistant.utils import get_table_names_from_query - sql = "SELECT * FROM explorer_query" - ret = get_table_names_from_query(sql) - self.assertEqual(ret, ["explorer_query"]) - - def test_parsing_tables_from_no_tables(self): - from explorer.assistant.utils import get_table_names_from_query - sql = "select 1;" - ret = get_table_names_from_query(sql) - self.assertEqual(ret, []) + self.assertEqual(ret, "col1 | col2\nval1 | val2") def test_schema_info_from_table_names(self): - from explorer.assistant.utils import tables_from_schema_info - ret = tables_from_schema_info(default_db_connection(), ["explorer_query"]) - expected = [("explorer_query", [ + from explorer.assistant.utils import table_schema + ret = table_schema(default_db_connection(), "explorer_query") + expected = [ ("id", "AutoField"), ("title", "CharField"), ("sql", "TextField"), @@ -244,26 +220,65 @@ def test_schema_info_from_table_names(self): ("created_by_user_id", "IntegerField"), ("snapshot", "BooleanField"), ("connection", "CharField"), - ("database_connection_id", "IntegerField")])] + ("database_connection_id", "IntegerField"), + ("few_shot", "BooleanField")] self.assertEqual(ret, expected) @unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled") class TestAssistantUtils(TestCase): - def test_sample_rows_from_tables(self): - from explorer.assistant.utils import sample_rows_from_tables + def test_sample_rows_from_table(self): + from explorer.assistant.utils import sample_rows_from_table, format_rows_from_table SimpleQueryFactory(title="First Query") SimpleQueryFactory(title="Second Query") QueryLogFactory() - ret = sample_rows_from_tables(conn(), ["explorer_query", "explorer_querylog"]) + ret = sample_rows_from_table(conn(), "explorer_query") + self.assertEqual(len(ret), ROW_SAMPLE_SIZE) + ret = format_rows_from_table(ret) self.assertTrue("First Query" in ret) self.assertTrue("Second Query" in ret) - self.assertTrue("explorer_querylog" in ret) - def test_sample_rows_from_tables_no_tables(self): - from explorer.assistant.utils import sample_rows_from_tables + def test_sample_rows_from_tables_no_table_match(self): + from explorer.assistant.utils import sample_rows_from_table SimpleQueryFactory(title="First Query") SimpleQueryFactory(title="Second Query") - ret = sample_rows_from_tables(conn(), []) - self.assertEqual(ret, "") + ret = sample_rows_from_table(conn(), "banana") + self.assertEqual(ret, [["no such table: banana"]]) + + def test_relevant_few_shots(self): + relevant_q1 = SimpleQueryFactory(sql="select * from relevant_table", few_shot=True) + relevant_q2 = SimpleQueryFactory(sql="select * from conn.RELEVANT_TABLE limit 10", few_shot=True) + irrelevant_q2 = SimpleQueryFactory(sql="select * from conn.RELEVANT_TABLE limit 10", few_shot=False) + relevant_q3 = SimpleQueryFactory(sql="select * from conn.another_good_table limit 10", few_shot=True) + irrelevant_q1 = SimpleQueryFactory(sql="select * from irrelevant_table") + included_tables = ["relevant_table", "ANOTHER_GOOD_TABLE"] + res = get_relevant_few_shots(relevant_q1.database_connection, included_tables) + res_ids = [td.id for td in res] + self.assertIn(relevant_q1.id, res_ids) + self.assertIn(relevant_q2.id, res_ids) + self.assertIn(relevant_q3.id, res_ids) + self.assertNotIn(irrelevant_q1.id, res_ids) + self.assertNotIn(irrelevant_q2.id, res_ids) + + def test_get_relevant_annotations(self): + + relevant1 = TableDescription( + database_connection=default_db_connection(), + table_name="fruit" + ) + relevant2 = TableDescription( + database_connection=default_db_connection(), + table_name="Vegetables" + ) + irrelevant = TableDescription( + database_connection=default_db_connection(), + table_name="animals" + ) + relevant1.save() + relevant2.save() + irrelevant.save() + res1 = get_relevant_annotation(default_db_connection(), "Fruit") + self.assertEqual(relevant1.id, res1.id) + res2 = get_relevant_annotation(default_db_connection(), "vegetables") + self.assertEqual(relevant2.id, res2.id) diff --git a/explorer/tests/test_views.py b/explorer/tests/test_views.py index 883e1528..20299ca7 100644 --- a/explorer/tests/test_views.py +++ b/explorer/tests/test_views.py @@ -23,6 +23,7 @@ from explorer.utils import user_can_see_query from explorer.ee.db_connections.utils import default_db_connection from explorer.schema import connection_schema_cache_key, connection_schema_json_cache_key +from explorer.assistant.models import TableDescription def reload_app_settings(): @@ -1071,12 +1072,6 @@ def test_validate_connection_invalid_form(self): self.assertEqual(response.status_code, 200) self.assertJSONEqual(response.content, {"success": False, "error": "Invalid form data"}) - def test_validate_connection_success_alt_url(self): - url = reverse("explorer_connection_validate_with_pk", args=[1]) - response = self.client.post(url, data=self.valid_data) - self.assertEqual(response.status_code, 200) - self.assertJSONEqual(response.content, {"success": True}) - def test_update_existing_connection(self): DatabaseConnection.objects.create(alias="test_alias", engine="django.db.backends.sqlite3", name=":memory:") response = self.client.post(self.url, data=self.valid_data) @@ -1179,3 +1174,15 @@ def test_database_connection_delete_view(self): def test_database_connection_upload_view(self): response = self.client.get(reverse("explorer_upload_create")) self.assertEqual(response.status_code, 200) + + def test_table_description_list_view(self): + td = TableDescription(database_connection=default_db_connection(), table_name="foo", description="annotated") + td.save() + response = self.client.get(reverse("table_description_list")) + self.assertEqual(response.status_code, 200) + + response = self.client.get(reverse("table_description_update", args=[td.pk])) + self.assertEqual(response.status_code, 200) + + response = self.client.get(reverse("table_description_create")) + self.assertEqual(response.status_code, 200) diff --git a/explorer/urls.py b/explorer/urls.py index 36fec7bb..c612ba5c 100644 --- a/explorer/urls.py +++ b/explorer/urls.py @@ -6,7 +6,7 @@ ListQueryView, PlayQueryView, QueryFavoritesView, QueryFavoriteView, QueryView, SchemaJsonView, SchemaView, StreamQueryView, format_sql ) -from explorer.assistant.views import AssistantHelpView +from explorer.assistant.urls import assistant_urls urlpatterns = [ path( @@ -43,7 +43,7 @@ path("favorites/", QueryFavoritesView.as_view(), name="query_favorites"), path("favorite/", QueryFavoriteView.as_view(), name="query_favorite"), path("", ListQueryView.as_view(), name="explorer_index"), - path("assistant/", AssistantHelpView.as_view(), name="assistant"), ] +urlpatterns += assistant_urls urlpatterns += ee_urls diff --git a/package-lock.json b/package-lock.json index 1cbcab11..03b0fd4c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,6 +11,7 @@ "@codemirror/language-data": "^6.3.1", "bootstrap": "^5.0.1", "bootstrap-icons": "^1.11.2", + "choices.js": "^10.2.0", "codemirror": "^6.0.1", "cookiejs": "^2.1.3", "dompurify": "^3.0.7", @@ -24,6 +25,20 @@ "vite": "^5.0.13", "vite-plugin-copy": "^0.1.6", "vite-plugin-static-copy": "^1.0.5" + }, + "optionalDependencies": { + "@rollup/rollup-linux-x64-gnu": "^4.18.1" + } + }, + "node_modules/@babel/runtime": { + "version": "7.25.0", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.0.tgz", + "integrity": "sha512-7dRy4DwXwtzBrPbZflqxnvfxLF8kdZXPkhymtDeFoFqE6ldzjQFgYTtYIFARcLEYDrqfBfYcZt1WqFxRoyC9Rw==", + "dependencies": { + "regenerator-runtime": "^0.14.0" + }, + "engines": { + "node": ">=6.9.0" } }, "node_modules/@codemirror/autocomplete": { @@ -981,13 +996,12 @@ ] }, "node_modules/@rollup/rollup-linux-x64-gnu": { - "version": "4.9.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.1.tgz", - "integrity": "sha512-kr8rEPQ6ns/Lmr/hiw8sEVj9aa07gh1/tQF2Y5HrNCCEPiCBGnBUt9tVusrcBBiJfIt1yNaXN6r1CCmpbFEDpg==", + "version": "4.21.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.21.0.tgz", + "integrity": "sha512-e2hrvElFIh6kW/UNBQK/kzqMNY5mO+67YtEh9OA65RM5IJXYTWiXjX6fjIiPaqOkBthYF1EqgiZ6OXKcQsM0hg==", "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -1112,6 +1126,16 @@ "node": ">=8" } }, + "node_modules/choices.js": { + "version": "10.2.0", + "resolved": "https://registry.npmjs.org/choices.js/-/choices.js-10.2.0.tgz", + "integrity": "sha512-8PKy6wq7BMjNwDTZwr3+Zry6G2+opJaAJDDA/j3yxvqSCnvkKe7ZIFfIyOhoc7htIWFhsfzF9tJpGUATcpUtPg==", + "dependencies": { + "deepmerge": "^4.2.2", + "fuse.js": "^6.6.2", + "redux": "^4.2.0" + } + }, "node_modules/chokidar": { "version": "3.5.3", "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", @@ -1166,6 +1190,14 @@ "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz", "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==" }, + "node_modules/deepmerge": { + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-4.3.1.tgz", + "integrity": "sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/dompurify": { "version": "3.0.7", "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.0.7.tgz", @@ -1274,6 +1306,14 @@ "node": "^8.16.0 || ^10.6.0 || >=11.0.0" } }, + "node_modules/fuse.js": { + "version": "6.6.2", + "resolved": "https://registry.npmjs.org/fuse.js/-/fuse.js-6.6.2.tgz", + "integrity": "sha512-cJaJkxCCxC8qIIcPBF9yGxY0W/tVZS3uEISDxhYIdtk8OL93pe+6Zj7LjCqVV4dzbqcriOZ+kQ/NE4RXZHsIGA==", + "engines": { + "node": ">=10" + } + }, "node_modules/glob-parent": { "version": "5.1.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", @@ -1506,6 +1546,19 @@ "node": ">=8.10.0" } }, + "node_modules/redux": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/redux/-/redux-4.2.1.tgz", + "integrity": "sha512-LAUYz4lc+Do8/g7aeRa8JkyDErK6ekstQaqWQrNRW//MY1TvCEpMtpTWvlQ+FPbWCx+Xixu/6SHt5N0HR+SB4w==", + "dependencies": { + "@babel/runtime": "^7.9.2" + } + }, + "node_modules/regenerator-runtime": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", + "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==" + }, "node_modules/reusify": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", @@ -1545,6 +1598,19 @@ "fsevents": "~2.3.2" } }, + "node_modules/rollup/node_modules/@rollup/rollup-linux-x64-gnu": { + "version": "4.9.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.1.tgz", + "integrity": "sha512-kr8rEPQ6ns/Lmr/hiw8sEVj9aa07gh1/tQF2Y5HrNCCEPiCBGnBUt9tVusrcBBiJfIt1yNaXN6r1CCmpbFEDpg==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, "node_modules/run-parallel": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", diff --git a/package.json b/package.json index c7e89cad..53833c52 100644 --- a/package.json +++ b/package.json @@ -28,6 +28,7 @@ "@codemirror/language-data": "^6.3.1", "bootstrap": "^5.0.1", "bootstrap-icons": "^1.11.2", + "choices.js": "^10.2.0", "codemirror": "^6.0.1", "cookiejs": "^2.1.3", "dompurify": "^3.0.7", diff --git a/requirements/extra/assistant.txt b/requirements/extra/assistant.txt index c5ae0910..44fbead2 100644 --- a/requirements/extra/assistant.txt +++ b/requirements/extra/assistant.txt @@ -1,3 +1 @@ openai>=1.6.1 -sql_metadata>=2.10 -tiktoken>=0.7 From 84f8ca2282cea9e7cea47ab9a5b89ffdc0ee0535 Mon Sep 17 00:00:00 2001 From: Chris Clark Date: Tue, 27 Aug 2024 09:01:57 -0400 Subject: [PATCH 2/4] various improvements --- HISTORY.rst | 19 ++++--- explorer/ee/db_connections/models.py | 25 ++++++++-- explorer/src/js/assistant.js | 27 ++++++++-- explorer/src/js/codemirror-config.js | 4 -- explorer/src/js/uploads.js | 2 +- explorer/tests/test_models.py | 74 ++++++++++++++++++++++++++++ 6 files changed, 129 insertions(+), 22 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index e76ea74d..7f2b7f31 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -7,11 +7,6 @@ This project adheres to `Semantic Versioning `_. vNext =========================== -* Keyboard shortcut for formatting the SQL in the editor. - - - Cmd+Shift+F (Windows: Ctrl+Shift+F) - - The format button has been moved tobe a small icon towards the bottom-right of the SQL editor. - * `#664`_: Improvements to the AI SQL Assistant: - Table Annotations: Write persistent table annotations with descriptive information that will get injected into the @@ -23,8 +18,9 @@ vNext the same tables as your assistant request will get included as 'reference sql' in the prompt for the LLM. - Autocomplete / multiselect when selecting tables info to send to the SQL Assistant. Much easier and more keyboard focused. - - Relevant tables are added client-side visually, in real time, based on what's in the SQL editor. The dependency on - sql_metadata is therefore removed, as server-side SQL parsing is no longer necessary + - Relevant tables are added client-side visually, in real time, based on what's in the SQL editor and/or any tables + mentioned in the assistant request. The dependency on sql_metadata is therefore removed, as server-side SQL parsing + is no longer necessary. - Improved system prompt that emphasizes the particular SQL dialect being used. - Addresses issue #657. @@ -42,6 +38,13 @@ vNext are kept around in this release in case there is an unforeseen issue with the migration. Preserving the fields for now ensures there is no data loss in the event that a rollback to an earlier version is required. +* Fixed a bug when validating connections to uploaded files. Also added basic locking when downloading files from S3. + +* Keyboard shortcut for formatting the SQL in the editor. + + - Cmd+Shift+F (Windows: Ctrl+Shift+F) + - The format button has been moved tobe a small icon towards the bottom-right of the SQL editor. + `5.2.0`_ (2024-08-19) =========================== * `#651`_: Ability to append an upload to a previously uploaded file/sqlite DB as a new table @@ -670,6 +673,8 @@ Initial Release .. _#651: https://github.com/explorerhq/sql-explorer/pull/651 .. _#659: https://github.com/explorerhq/sql-explorer/pull/659 .. _#662: https://github.com/explorerhq/sql-explorer/pull/662 +.. _#660: https://github.com/explorerhq/sql-explorer/pull/660 +.. _#664: https://github.com/explorerhq/sql-explorer/pull/664 .. _#269: https://github.com/explorerhq/sql-explorer/issues/269 .. _#288: https://github.com/explorerhq/sql-explorer/issues/288 diff --git a/explorer/ee/db_connections/models.py b/explorer/ee/db_connections/models.py index ea81a8f8..140543ac 100644 --- a/explorer/ee/db_connections/models.py +++ b/explorer/ee/db_connections/models.py @@ -4,7 +4,7 @@ from django.db.utils import load_backend from explorer.app_settings import EXPLORER_CONNECTIONS from explorer.ee.db_connections.utils import quick_hash, uploaded_db_local_path - +from django.core.cache import cache from django_cryptography.fields import encrypt @@ -66,12 +66,27 @@ def _download_sqlite(self): s3 = get_s3_bucket() s3.download_file(self.host, self.local_name) + def _download_needed(self): + # If the file doesn't exist, obviously we need to download it + # If it does exist, then check if it's out of date. But only check if in fact the DatabaseConnection has been + # saved to the DB. For example, we might be validating an unsaved connection, in which case the fingerprint + # won't be set yet. + return (not os.path.exists(self.local_name) or + (self.id is not None and self.local_fingerprint() != self.upload_fingerprint)) + def download_sqlite_if_needed(self): - download = not os.path.exists(self.local_name) or self.local_fingerprint() != self.upload_fingerprint - if download: - self._download_sqlite() - self.update_fingerprint() + if self._download_needed(): + cache_key = f"download_lock_{self.local_name}" + lock_acquired = cache.add(cache_key, "locked", timeout=300) # Timeout after 5 minutes + + if lock_acquired: + try: + if self._download_needed(): + self._download_sqlite() + self.update_fingerprint() + finally: + cache.delete(cache_key) @property def is_upload(self): diff --git a/explorer/src/js/assistant.js b/explorer/src/js/assistant.js index b6e52026..1f00dc63 100644 --- a/explorer/src/js/assistant.js +++ b/explorer/src/js/assistant.js @@ -10,6 +10,14 @@ function getErrorMessage() { return errorElement ? errorElement.textContent.trim() : null; } +function debounce(func, delay) { + let timeout; + return function(...args) { + clearTimeout(timeout); + timeout = setTimeout(() => func.apply(this, args), delay); + }; +} + function setupTableList() { if(window.assistantChoices) { @@ -60,24 +68,33 @@ function setupTableList() { }); }); - selectRelevantTables(choices, keys); + selectRelevantTablesSql(choices, keys); + + document.addEventListener('docChanged', debounce( + () => selectRelevantTablesSql(choices, keys), 500)); + + document.getElementById('id_assistant_input').addEventListener('input', debounce( + () => selectRelevantTablesRequest(choices, keys), 300)); - document.addEventListener('docChanged', (e) => { - selectRelevantTables(choices, keys); - }); }) .catch(error => { console.error('Error retrieving JSON schema:', error); }); } -function selectRelevantTables(choices, keys) { +function selectRelevantTablesSql(choices, keys) { const textContent = window.editor.state.doc.toString(); const textWords = new Set(textContent.split(/\s+/)); const hasKeys = keys.filter(key => textWords.has(key)); choices.setChoiceByValue(hasKeys); } +function selectRelevantTablesRequest(choices, keys) { + const textContent = document.getElementById("id_assistant_input").value + const textWords = new Set(textContent.split(/\s+/)); + const hasKeys = keys.filter(key => textWords.has(key)); + choices.setChoiceByValue(hasKeys); +} export function setUpAssistant(expand = false) { diff --git a/explorer/src/js/codemirror-config.js b/explorer/src/js/codemirror-config.js index 08586fe1..3944c522 100644 --- a/explorer/src/js/codemirror-config.js +++ b/explorer/src/js/codemirror-config.js @@ -14,14 +14,10 @@ import { Prec } from "@codemirror/state"; import {sql} from "@codemirror/lang-sql"; import { SchemaSvc } from "./schemaService" -let debounceTimeout; let updateListenerExtension = EditorView.updateListener.of((update) => { if (update.docChanged) { - clearTimeout(debounceTimeout); - debounceTimeout = setTimeout(() => { document.dispatchEvent(new CustomEvent('docChanged', {})); - }, 500); } }); diff --git a/explorer/src/js/uploads.js b/explorer/src/js/uploads.js index 508eb9a5..aa0cd224 100644 --- a/explorer/src/js/uploads.js +++ b/explorer/src/js/uploads.js @@ -91,7 +91,7 @@ export function setupUploads() { let form = document.getElementById("db-connection-form"); let formData = new FormData(form); - fetch(`${window.baseUrlPath}validate/`, { + fetch(`${window.baseUrlPath}connections/validate/`, { method: "POST", body: formData, headers: { diff --git a/explorer/tests/test_models.py b/explorer/tests/test_models.py index 05588f0e..915d3168 100644 --- a/explorer/tests/test_models.py +++ b/explorer/tests/test_models.py @@ -257,6 +257,80 @@ def test_local_name_calls_user_dbs_local_dir(self, mock_getcwd, mock_exists, moc # Ensure os.makedirs was called once since the directory does not exist mock_makedirs.assert_called_once_with("/mocked/path/user_dbs") + @patch("explorer.utils.get_s3_bucket") + @patch("explorer.ee.db_connections.models.cache") + def test_single_download_triggered(self, mock_cache, mock_get_s3_bucket): + # Setup mocks + mock_cache.add.return_value = True # Simulate acquiring the lock + mock_s3 = MagicMock() + mock_get_s3_bucket.return_value = mock_s3 + + # Call the method + instance = DatabaseConnection( + alias="test", + engine=DatabaseConnection.SQLITE, + name="test_db.sqlite3", + host="some-s3-bucket", + id=123 + ) + instance.download_sqlite_if_needed() + + # Assertions + mock_s3.download_file.assert_called_once() + mock_cache.add.assert_called_once() + mock_cache.delete.assert_called_once() + + @patch("explorer.utils.get_s3_bucket") + @patch("explorer.ee.db_connections.models.cache") + def test_skip_download_when_locked(self, mock_cache, mock_get_s3_bucket): + # Setup mocks + mock_cache.add.return_value = False # Simulate that another process has the lock + mock_s3 = MagicMock() + mock_get_s3_bucket.return_value = mock_s3 + + # Call the method + instance = DatabaseConnection( + alias="test", + engine=DatabaseConnection.SQLITE, + name="test_db.sqlite3", + host="some-s3-bucket", + id=123 + ) + instance.download_sqlite_if_needed() + + # Assertions + mock_s3.download_file.assert_not_called() + mock_cache.add.assert_called_once() + mock_cache.delete.assert_not_called() + + @patch("explorer.utils.get_s3_bucket") + def test_not_downloaded_if_file_exists_and_model_is_unsaved(self, mock_get_s3_bucket): + + # Note this is NOT being saved to disk, e.g. how a DatabaseValidate would work. + connection = DatabaseConnection( + alias="test", + engine=DatabaseConnection.SQLITE, + name="test_db.sqlite3", + host="some-s3-bucket", + ) + + def mock_download_file(path, filename): pass + mock_s3 = mock_get_s3_bucket.return_value + mock_s3.download_file = MagicMock(side_effect=mock_download_file) + + # write the file + with open(connection.local_name, "w") as f: + f.write("Initial content") + + # See if it downloads + connection.download_sqlite_if_needed() + + # And it shouldn't.... + mock_s3.download_file.assert_not_called() + + # ...even though the fingerprints don't match + self.assertIsNone(connection.upload_fingerprint) + @patch("explorer.utils.get_s3_bucket") def test_fingerprint_is_updated_after_download_and_download_is_not_called_again(self, mock_get_s3_bucket): # Setup From 20b54248cf784d79844a3096124519959f835bc4 Mon Sep 17 00:00:00 2001 From: Chris Clark Date: Wed, 28 Aug 2024 09:23:36 -0400 Subject: [PATCH 3/4] assistant history --- explorer/app_settings.py | 6 ++ explorer/assistant/models.py | 2 + explorer/assistant/urls.py | 4 +- explorer/assistant/views.py | 24 ++++- ...abase_connection_promptlog_user_request.py | 24 +++++ explorer/src/js/assistant.js | 90 +++++++++++++++++++ explorer/templates/explorer/assistant.html | 38 ++++---- test_project/settings.py | 1 + 8 files changed, 166 insertions(+), 23 deletions(-) create mode 100644 explorer/migrations/0028_promptlog_database_connection_promptlog_user_request.py diff --git a/explorer/app_settings.py b/explorer/app_settings.py index a93257b3..8bf69b63 100644 --- a/explorer/app_settings.py +++ b/explorer/app_settings.py @@ -152,11 +152,17 @@ EXPLORER_AI_API_KEY = getattr(settings, "EXPLORER_AI_API_KEY", None) EXPLORER_ASSISTANT_BASE_URL = getattr(settings, "EXPLORER_ASSISTANT_BASE_URL", "https://api.openai.com/v1") + +# Deprecated. Will be removed in a future release. Please use EXPLORER_ASSISTANT_MODEL_NAME instead EXPLORER_ASSISTANT_MODEL = getattr(settings, "EXPLORER_ASSISTANT_MODEL", # Return the model name and max_tokens it supports {"name": "gpt-4o", "max_tokens": 128000}) +EXPLORER_ASSISTANT_MODEL_NAME = getattr(settings, "EXPLORER_ASSISTANT_MODEL_NAME", + EXPLORER_ASSISTANT_MODEL["name"]) + + EXPLORER_DB_CONNECTIONS_ENABLED = getattr(settings, "EXPLORER_DB_CONNECTIONS_ENABLED", False) EXPLORER_USER_UPLOADS_ENABLED = getattr(settings, "EXPLORER_USER_UPLOADS_ENABLED", False) EXPLORER_PRUNE_LOCAL_UPLOAD_COPY_DAYS_INACTIVITY = getattr(settings, diff --git a/explorer/assistant/models.py b/explorer/assistant/models.py index 87d7acff..556148c1 100644 --- a/explorer/assistant/models.py +++ b/explorer/assistant/models.py @@ -9,6 +9,7 @@ class Meta: app_label = "explorer" prompt = models.TextField(blank=True) + user_request = models.TextField(blank=True) response = models.TextField(blank=True) run_by_user = models.ForeignKey( settings.AUTH_USER_MODEL, @@ -20,6 +21,7 @@ class Meta: duration = models.FloatField(blank=True, null=True) # seconds model = models.CharField(blank=True, max_length=128, default="") error = models.TextField(blank=True, null=True) + database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, blank=True, null=True) class TableDescription(models.Model): diff --git a/explorer/assistant/urls.py b/explorer/assistant/urls.py index ece81fbd..bf6e479b 100644 --- a/explorer/assistant/urls.py +++ b/explorer/assistant/urls.py @@ -3,10 +3,12 @@ TableDescriptionCreateView, TableDescriptionUpdateView, TableDescriptionDeleteView, - AssistantHelpView) + AssistantHelpView, + AssistantHistoryApiView) assistant_urls = [ path("assistant/", AssistantHelpView.as_view(), name="assistant"), + path("assistant/history/", AssistantHistoryApiView.as_view(), name="assistant_history"), path("table-descriptions/", TableDescriptionListView.as_view(), name="table_description_list"), path("table-descriptions/new/", TableDescriptionCreateView.as_view(), name="table_description_create"), path("table-descriptions//update/", TableDescriptionUpdateView.as_view(), name="table_description_update"), diff --git a/explorer/assistant/views.py b/explorer/assistant/views.py index c6507e2c..2571e216 100644 --- a/explorer/assistant/views.py +++ b/explorer/assistant/views.py @@ -30,8 +30,8 @@ def run_assistant(request_data, user): conn = DatabaseConnection.objects.get(id=connection_id) except DatabaseConnection.DoesNotExist: return "Error: Connection not found" - - prompt = build_prompt(conn, request_data.get("assistant_request"), + assistant_request = request_data.get("assistant_request") + prompt = build_prompt(conn, assistant_request, included_tables, request_data.get("db_error"), request_data.get("sql")) start = timezone.now() @@ -39,6 +39,8 @@ def run_assistant(request_data, user): prompt=prompt, run_by_user=user, run_at=timezone.now(), + user_request=assistant_request, + database_connection=conn ) response_text = None try: @@ -102,3 +104,21 @@ class TableDescriptionDeleteView(PermissionRequiredMixin, ExplorerContextMixin, permission_required = "change_permission" template_name = "assistant/table_description_confirm_delete.html" success_url = reverse_lazy("table_description_list") + + +class AssistantHistoryApiView(View): + + def post(self, request, *args, **kwargs): + try: + data = json.loads(request.body) + logs = PromptLog.objects.filter( + run_by_user=request.user, + database_connection_id=data["connection_id"] + ).order_by("-run_at")[:5] + ret = [{ + "user_request": log.user_request, + "response": log.response + } for log in logs] + return JsonResponse({"logs": ret}) + except json.JSONDecodeError: + return JsonResponse({"status": "error", "message": "Invalid JSON"}, status=400) diff --git a/explorer/migrations/0028_promptlog_database_connection_promptlog_user_request.py b/explorer/migrations/0028_promptlog_database_connection_promptlog_user_request.py new file mode 100644 index 00000000..dfba7809 --- /dev/null +++ b/explorer/migrations/0028_promptlog_database_connection_promptlog_user_request.py @@ -0,0 +1,24 @@ +# Generated by Django 5.0.4 on 2024-08-27 18:59 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('explorer', '0027_query_few_shot'), + ] + + operations = [ + migrations.AddField( + model_name='promptlog', + name='database_connection', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='explorer.databaseconnection'), + ), + migrations.AddField( + model_name='promptlog', + name='user_request', + field=models.TextField(blank=True), + ), + ] diff --git a/explorer/src/js/assistant.js b/explorer/src/js/assistant.js index 1f00dc63..eb052c3c 100644 --- a/explorer/src/js/assistant.js +++ b/explorer/src/js/assistant.js @@ -125,8 +125,98 @@ export function setUpAssistant(expand = false) { }); document.getElementById('ask_assistant_btn').addEventListener('click', submitAssistantAsk); + + document.getElementById('assistant_history').addEventListener('click', getAssistantHistory); + +} + +function getAssistantHistory() { + + // Remove any existing modal with the same ID + const existingModal = document.getElementById('historyModal'); + if (existingModal) { + existingModal.remove() + } + + const data = { + connection_id: document.getElementById("id_database_connection")?.value ?? null + }; + + fetch(`${window.baseUrlPath}assistant/history/`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-CSRFToken': getCsrfToken() + }, + body: JSON.stringify(data) + }) + .then(response => { + if (!response.ok) { + throw new Error('Network response was not ok'); + } + return response.json(); + }) + .then(data => { + // Create table rows from the fetched data + let tableRows = ''; + data.logs.forEach(log => { + let md = DOMPurify.sanitize(marked.parse(log.response)); + tableRows += ` + + ${log.user_request} + ${md} + + `; + }); + + // Create the complete table HTML + const tableHtml = ` + + + + + + + + + ${tableRows} + +
User RequestResponse
+ `; + + // Insert the table into a new Bootstrap modal + const modalHtml = ` + + `; + + // Append the modal to the body + document.body.insertAdjacentHTML('beforeend', modalHtml); + + // Show the modal + const historyModal = new bootstrap.Modal(document.getElementById('historyModal')); + historyModal.show(); + }) + .catch(error => { + console.error('There was a problem with the fetch operation:', error); + }); } + function submitAssistantAsk() { const data = { diff --git a/explorer/templates/explorer/assistant.html b/explorer/templates/explorer/assistant.html index f3efd243..9088059d 100644 --- a/explorer/templates/explorer/assistant.html +++ b/explorer/templates/explorer/assistant.html @@ -9,6 +9,14 @@
+
+
+
+

+ Loading... +

+
+