diff --git a/HISTORY.rst b/HISTORY.rst index 21e6d4bb..764b602d 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -7,13 +7,44 @@ This project adheres to `Semantic Versioning `_. vNext =========================== -* `#660`_: Userspace connection migration. This should be an invisible change, but represents a significant refactor of how connections function. -Instead of a weird blend of DatabaseConnection models and underlying Django models (which were the original Explorer connections), -this migrates all connections to DatabaseConnection models and implements proper foreign keys to them on the Query and QueryLog models. -A data migration creates new DatabaseConnection models based on the configured settings.EXPLORER_CONNECTIONS. -Going forward, admins can create new Django-backed DatabaseConnection models by registering the connection in EXPLORER_CONNECTIONS, and then creating a -DatabaseConnection model using the Django admin or the user-facing /connections/new/ form, and entering the Django DB alias and setting the connection type to "Django Connection" - +* `#664`_: Improvements to the AI SQL Assistant: + + - Table Annotations: Write persistent table annotations with descriptive information that will get injected into the + prompt for the assistant. For example, if a table is commonly joined to another table through a non-obvious foreign + key, you can tell the assistant about it in plain English, as an annotation to that table. Every time that table is + deemed 'relevant' to an assistant request, that annotation will be included alongside the schema and sample data. + - Few-Shot Examples: Using the small checkbox on the bottom-right of any saved query, you can designate certain + queries as 'few-shot examples'. When making an assistant request, any designated few-shot examples that reference + the same tables as your assistant request will get included as 'reference sql' in the prompt for the LLM. 
+ - Autocomplete / multiselect when selecting tables info to send to the SQL Assistant. Much easier and more keyboard + focused. + - Relevant tables are added client-side visually, in real time, based on what's in the SQL editor and/or any tables + mentioned in the assistant request. The dependency on sql_metadata is therefore removed, as server-side SQL parsing + is no longer necessary. + - Ability to view Assistant request/response history. + - Improved system prompt that emphasizes the particular SQL dialect being used. + - Addresses issue #657. + +* `#660`_: Userspace connection migration. + + - This should be an invisible change, but represents a significant refactor of how connections function. Instead of a + weird blend of DatabaseConnection models and underlying Django models (which were the original Explorer + connections), this migrates all connections to DatabaseConnection models and implements proper foreign keys to them + on the Query and QueryLog models. A data migration creates new DatabaseConnection models based on the configured + settings.EXPLORER_CONNECTIONS. Going forward, admins can create new Django-backed DatabaseConnection models by + registering the connection in EXPLORER_CONNECTIONS, and then creating a DatabaseConnection model using the Django + admin or the user-facing /connections/new/ form, and entering the Django DB alias and setting the connection type + to "Django Connection". + - The Query.connection and QueryLog.connection fields are deprecated and will be removed in a future release. They + are kept around in this release in case there is an unforeseen issue with the migration. Preserving the fields for + now ensures there is no data loss in the event that a rollback to an earlier version is required. + +* Fixed a bug when validating connections to uploaded files. Also added basic locking when downloading files from S3. + +* Keyboard shortcut for formatting the SQL in the editor. 
+ + - Cmd+Shift+F (Windows: Ctrl+Shift+F) + - The format button has been moved to be a small icon towards the bottom-right of the SQL editor. `5.2.0`_ (2024-08-19) =========================== @@ -643,6 +674,8 @@ Initial Release .. _#651: https://github.com/explorerhq/sql-explorer/pull/651 .. _#659: https://github.com/explorerhq/sql-explorer/pull/659 .. _#662: https://github.com/explorerhq/sql-explorer/pull/662 +.. _#660: https://github.com/explorerhq/sql-explorer/pull/660 +.. _#664: https://github.com/explorerhq/sql-explorer/pull/664 .. _#269: https://github.com/explorerhq/sql-explorer/issues/269 .. _#288: https://github.com/explorerhq/sql-explorer/issues/288 diff --git a/docs/features.rst b/docs/features.rst index a93e9b39..c9dd8f42 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -5,7 +5,19 @@ SQL Assistant ------------- - Built in integration with OpenAI (or the LLM of your choosing) to quickly get help with your query, with relevant schema - automatically injected into the prompt. Simple, effective. + automatically injected into the prompt. +- The assistant tries hard to get relevant context into the prompt to the LLM, alongside your explicit request. You + can choose tables to include explicitly (and any tables you reference in your SQL will be included as + well). When a table is "included", the prompt will include the schema of the table, 3 sample rows, any Table + Annotations you have added, and any designated "few shot examples". More on each of those below. +- Table Annotations: Write persistent table annotations with descriptive information that will get injected into the + prompt for the assistant. For example, if a table is commonly joined to another table through a non-obvious foreign + key, you can tell the assistant about it in plain English, as an annotation to that table. Every time that table is + deemed 'relevant' to an assistant request, that annotation will be included alongside the schema and sample data. 
+- Few-shot examples: Using the small checkbox on the bottom-right of any saved query, you can designate queries as + "Assistant Examples". When making an assistant request, the 'included tables' are intersected with tables referenced + by designated Example queries, and those queries are injected into the prompt, and the LLM is told that these + are good reference queries. Database Support ---------------- @@ -222,8 +234,7 @@ Power tips view. - Command+Enter and Ctrl+Enter will execute a query when typing in the SQL editor area. -- Hit the "Format" button to format and clean up your SQL (this is - non-validating -- just formatting). +- Cmd+Shift+F (Windows: Ctrl+Shift+F) to format the SQL in the editor. - Use the Query Logs feature to share one-time queries that aren't worth creating a persistent query for. Just run your SQL in the playground, then navigate to ``/logs`` and share the link diff --git a/explorer/admin.py b/explorer/admin.py index 0cad25ac..219c87ce 100644 --- a/explorer/admin.py +++ b/explorer/admin.py @@ -7,7 +7,7 @@ @admin.register(Query) class QueryAdmin(admin.ModelAdmin): - list_display = ("title", "description", "created_by_user",) + list_display = ("title", "description", "created_by_user", "few_shot") list_filter = ("title",) raw_id_fields = ("created_by_user",) actions = [generate_report_action()] diff --git a/explorer/app_settings.py b/explorer/app_settings.py index a93257b3..8bf69b63 100644 --- a/explorer/app_settings.py +++ b/explorer/app_settings.py @@ -152,11 +152,17 @@ EXPLORER_AI_API_KEY = getattr(settings, "EXPLORER_AI_API_KEY", None) EXPLORER_ASSISTANT_BASE_URL = getattr(settings, "EXPLORER_ASSISTANT_BASE_URL", "https://api.openai.com/v1") + +# Deprecated. Will be removed in a future release. 
Please use EXPLORER_ASSISTANT_MODEL_NAME instead EXPLORER_ASSISTANT_MODEL = getattr(settings, "EXPLORER_ASSISTANT_MODEL", # Return the model name and max_tokens it supports {"name": "gpt-4o", "max_tokens": 128000}) +EXPLORER_ASSISTANT_MODEL_NAME = getattr(settings, "EXPLORER_ASSISTANT_MODEL_NAME", + EXPLORER_ASSISTANT_MODEL["name"]) + + EXPLORER_DB_CONNECTIONS_ENABLED = getattr(settings, "EXPLORER_DB_CONNECTIONS_ENABLED", False) EXPLORER_USER_UPLOADS_ENABLED = getattr(settings, "EXPLORER_USER_UPLOADS_ENABLED", False) EXPLORER_PRUNE_LOCAL_UPLOAD_COPY_DAYS_INACTIVITY = getattr(settings, diff --git a/explorer/assistant/forms.py b/explorer/assistant/forms.py new file mode 100644 index 00000000..ddcaffea --- /dev/null +++ b/explorer/assistant/forms.py @@ -0,0 +1,42 @@ +from django import forms +from explorer.assistant.models import TableDescription +from explorer.ee.db_connections.utils import default_db_connection + + +class TableDescriptionForm(forms.ModelForm): + class Meta: + model = TableDescription + fields = "__all__" + widgets = { + "database_connection": forms.Select(attrs={"class": "form-select"}), + "description": forms.Textarea(attrs={"class": "form-control", "rows": 3}), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self.instance.pk: # Check if this is a new instance + # Set the default value for database_connection + self.fields["database_connection"].initial = default_db_connection() + + if self.instance and self.instance.table_name: + choices = [(self.instance.table_name, self.instance.table_name)] + else: + choices = [] + + f = forms.ChoiceField( + choices=choices, + widget=forms.Select(attrs={"class": "form-select", "data-placeholder": "Select table"}) + ) + + # We don't actually care about validating the 'choices' that the ChoiceField does by default. + # Really we are just using that field type in order to get a valid pre-populated Select widget on the client + # But also it can't be blank! 
+ def valid_value_new(v): + return bool(v) + + f.valid_value = valid_value_new + + self.fields["table_name"] = f + + if self.instance and self.instance.table_name: + self.fields["table_name"].initial = self.instance.table_name diff --git a/explorer/assistant/models.py b/explorer/assistant/models.py index 21c6db6d..556148c1 100644 --- a/explorer/assistant/models.py +++ b/explorer/assistant/models.py @@ -1,5 +1,6 @@ from django.db import models from django.conf import settings +from explorer.ee.db_connections.models import DatabaseConnection class PromptLog(models.Model): @@ -8,6 +9,7 @@ class Meta: app_label = "explorer" prompt = models.TextField(blank=True) + user_request = models.TextField(blank=True) response = models.TextField(blank=True) run_by_user = models.ForeignKey( settings.AUTH_USER_MODEL, @@ -19,3 +21,18 @@ class Meta: duration = models.FloatField(blank=True, null=True) # seconds model = models.CharField(blank=True, max_length=128, default="") error = models.TextField(blank=True, null=True) + database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, blank=True, null=True) + + +class TableDescription(models.Model): + + class Meta: + app_label = "explorer" + unique_together = ("database_connection", "table_name") + + database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.CASCADE) + table_name = models.CharField(max_length=512) + description = models.TextField() + + def __str__(self): + return f"{self.database_connection.alias} - {self.table_name}" diff --git a/explorer/assistant/urls.py b/explorer/assistant/urls.py new file mode 100644 index 00000000..bf6e479b --- /dev/null +++ b/explorer/assistant/urls.py @@ -0,0 +1,16 @@ +from django.urls import path +from explorer.assistant.views import (TableDescriptionListView, + TableDescriptionCreateView, + TableDescriptionUpdateView, + TableDescriptionDeleteView, + AssistantHelpView, + AssistantHistoryApiView) + +assistant_urls = [ + path("assistant/", 
AssistantHelpView.as_view(), name="assistant"), + path("assistant/history/", AssistantHistoryApiView.as_view(), name="assistant_history"), + path("table-descriptions/", TableDescriptionListView.as_view(), name="table_description_list"), + path("table-descriptions/new/", TableDescriptionCreateView.as_view(), name="table_description_create"), + path("table-descriptions//update/", TableDescriptionUpdateView.as_view(), name="table_description_update"), + path("table-descriptions//delete/", TableDescriptionDeleteView.as_view(), name="table_description_delete"), +] diff --git a/explorer/assistant/utils.py b/explorer/assistant/utils.py index 3dd3cdfb..0a962406 100644 --- a/explorer/assistant/utils.py +++ b/explorer/assistant/utils.py @@ -1,12 +1,16 @@ +from dataclasses import dataclass from explorer import app_settings from explorer.schema import schema_info -from explorer.models import ExplorerValue +from explorer.models import ExplorerValue, Query from django.db.utils import OperationalError +from django.db.models.functions import Lower +from django.db.models import Q +from explorer.assistant.models import TableDescription OPENAI_MODEL = app_settings.EXPLORER_ASSISTANT_MODEL["name"] -ROW_SAMPLE_SIZE = 2 -MAX_FIELD_SAMPLE_SIZE = 500 # characters +ROW_SAMPLE_SIZE = 3 +MAX_FIELD_SAMPLE_SIZE = 200 # characters def openai_client(): @@ -34,25 +38,17 @@ def extract_response(r): return r[-1].content -def tables_from_schema_info(db_connection, table_names): +def table_schema(db_connection, table_name): schema = schema_info(db_connection) - return [table for table in schema if table[0] in set(table_names)] - - -def sample_rows_from_tables(connection, table_names): - ret = "" - for table_name in table_names: - ret += f"SAMPLE FROM TABLE {table_name}:\n" - ret += format_rows_from_table( - sample_rows_from_table(connection, table_name) - ) + "\n\n" - return ret + s = [table for table in schema if table[0] == table_name] + if len(s): + return s[0][1] def 
sample_rows_from_table(connection, table_name): """ Fetches a sample of rows from the specified table and ensures that any field values - exceeding 500 characters (or bytes) are truncated. This is useful for handling fields + exceeding 200 characters (or bytes) are truncated. This is useful for handling fields like "description" that might contain very long strings of text or binary data. Truncating these fields prevents issues with displaying or processing overly large values. An ellipsis ("...") is appended to indicate that the data has been truncated. @@ -76,8 +72,8 @@ def sample_rows_from_table(connection, table_name): new_val = field if isinstance(field, str) and len(field) > MAX_FIELD_SAMPLE_SIZE: new_val = field[:MAX_FIELD_SAMPLE_SIZE] + "..." # Truncate and add ellipsis - elif isinstance(field, (bytes, bytearray)) and len(field) > MAX_FIELD_SAMPLE_SIZE: - new_val = field[:MAX_FIELD_SAMPLE_SIZE] + b"..." # Truncate binary data + elif isinstance(field, (bytes, bytearray)): + new_val = "" processed_row.append(new_val) ret.append(processed_row) @@ -87,66 +83,97 @@ def sample_rows_from_table(connection, table_name): def format_rows_from_table(rows): - column_headers = list(rows[0]) - ret = " | ".join(column_headers) + "\n" + "-" * 50 + "\n" - for row in rows[1:]: - row_str = " | ".join(str(item) for item in row) - ret += row_str + "\n" - return ret - - -def get_table_names_from_query(sql): - from sql_metadata import Parser - if sql: - try: - parsed = Parser(sql) - return parsed.tables - except ValueError: - return [] - return [] - - -def num_tokens_from_string(string: str) -> int: - """Returns the number of tokens in a text string.""" - import tiktoken - try: - encoding = tiktoken.encoding_for_model(OPENAI_MODEL) - except KeyError: - encoding = tiktoken.get_encoding("cl100k_base") - num_tokens = len(encoding.encode(string)) - return num_tokens + """ Given an array of rows (a list of lists), returns e.g. 
+AlbumId | Title | ArtistId +1 | For Those About To Rock We Salute You | 1 +2 | Let It Rip | 2 +3 | Restless and Wild | 2 -def fits_in_window(string: str) -> bool: - # Ratchet down by 5% to account for other boilerplate and system prompt - # TODO make this better by actually looking at the token count of the system prompt - return num_tokens_from_string(string) < (app_settings.EXPLORER_ASSISTANT_MODEL["max_tokens"] * 0.95) + """ + return "\n".join([" | ".join([str(item) for item in row]) for row in rows]) -def build_prompt(db_connection, assistant_request, included_tables, query_error=None, sql=None): - user_prompt = "" - djc = db_connection.as_django_connection() - user_prompt += f"## Database Vendor / SQL Flavor is {djc.vendor}\n\n" +def build_system_prompt(flavor): + bsp = ExplorerValue.objects.get_item(ExplorerValue.ASSISTANT_SYSTEM_PROMPT).value + bsp += f"\nYou are an expert at writing SQL, specifically for {flavor}, and account for the nuances of this dialect of SQL. You always respond with valid {flavor} SQL." 
# noqa + return bsp + + +def get_relevant_annotation(db_connection, t): + return TableDescription.objects.annotate( + table_name_lower=Lower("table_name") + ).filter( + database_connection=db_connection, + table_name_lower=t.lower() + ).first() + + +def get_relevant_few_shots(db_connection, included_tables): + included_tables_lower = [t.lower() for t in included_tables] + + query_conditions = Q() + for table in included_tables_lower: + query_conditions |= Q(sql__icontains=table) - if query_error: - user_prompt += f"## Query Error ##\n\n{query_error}\n\n" + return Query.objects.annotate( + sql_lower=Lower("sql") + ).filter( + database_connection=db_connection, + few_shot=True + ).filter(query_conditions) - if sql: - user_prompt += f"## Existing SQL ##\n\n{sql}\n\n" - results_sample = sample_rows_from_tables(djc, - included_tables) - if fits_in_window(user_prompt + results_sample): - user_prompt += f"## Table Structure with Sampled Data ##\n\n{results_sample}\n\n" - else: # If it's too large with sampling, then provide *just* the structure - table_struct = tables_from_schema_info(db_connection, - included_tables) - user_prompt += f"## Table Structure ##\n\n{table_struct}\n\n" +def get_few_shot_chunk(db_connection, included_tables): + included_tables = [t.lower() for t in included_tables] + few_shot_examples = get_relevant_few_shots(db_connection, included_tables) + if few_shot_examples: + return "## Relevant example queries, written by expert SQL analysts ##\n" + "\n\n".join( + [f"Description: {fs.title} - {fs.description}\nSQL:\n{fs.sql}" + for fs in few_shot_examples.all()] + ) + + +@dataclass +class TablePromptData: + name: str + schema: list + sample: list + annotation: TableDescription + + def render(self): + fmt_schema = "\n".join([str(field) for field in self.schema]) + ret = f"""## Information for Table '{self.name}' ## + +Schema:\n{fmt_schema} + +Sample rows:\n{format_rows_from_table(self.sample)}""" + if self.annotation: + ret += f"\nUsage 
Notes:\n{self.annotation.description}" + return ret + + +def build_prompt(db_connection, assistant_request, included_tables, query_error=None, sql=None): + included_tables = [t.lower() for t in included_tables] + + error_chunk = f"## Query Error ##\n{query_error}" if query_error else None + sql_chunk = f"## Existing User-Written SQL ##\n{sql}" if sql else None + request_chunk = f"## User's Request to Assistant ##\n{assistant_request}" + table_chunks = [ + TablePromptData( + name=t, + schema=table_schema(db_connection, t), + sample=sample_rows_from_table(db_connection.as_django_connection(), t), + annotation=get_relevant_annotation(db_connection, t) + ).render() + for t in included_tables + ] + few_shot_chunk = get_few_shot_chunk(db_connection, included_tables) - user_prompt += f"## User's Request to Assistant ##\n\n{assistant_request}\n\n" + chunks = [error_chunk, sql_chunk, *table_chunks, few_shot_chunk, request_chunk] prompt = { - "system": ExplorerValue.objects.get_item(ExplorerValue.ASSISTANT_SYSTEM_PROMPT).value, - "user": user_prompt + "system": build_system_prompt(db_connection.as_django_connection().vendor), + "user": "\n\n".join([c for c in chunks if c]), } return prompt diff --git a/explorer/assistant/views.py b/explorer/assistant/views.py index f5d6f2aa..2571e216 100644 --- a/explorer/assistant/views.py +++ b/explorer/assistant/views.py @@ -1,14 +1,21 @@ from django.http import JsonResponse from django.views import View from django.utils import timezone +from django.views.generic import ListView, CreateView, UpdateView, DeleteView +from django.urls import reverse_lazy + +from .forms import TableDescriptionForm +from .models import TableDescription + import json +from explorer.views.auth import PermissionRequiredMixin +from explorer.views.mixins import ExplorerContextMixin from explorer.telemetry import Stat, StatNames from explorer.ee.db_connections.models import DatabaseConnection from explorer.assistant.models import PromptLog from 
explorer.assistant.utils import ( do_req, extract_response, - get_table_names_from_query, build_prompt ) @@ -16,16 +23,15 @@ def run_assistant(request_data, user): sql = request_data.get("sql") - extra_tables = request_data.get("selected_tables", []) - included_tables = get_table_names_from_query(sql) + extra_tables + included_tables = request_data.get("selected_tables", []) connection_id = request_data.get("connection_id") try: conn = DatabaseConnection.objects.get(id=connection_id) except DatabaseConnection.DoesNotExist: return "Error: Connection not found" - - prompt = build_prompt(conn, request_data.get("assistant_request"), + assistant_request = request_data.get("assistant_request") + prompt = build_prompt(conn, assistant_request, included_tables, request_data.get("db_error"), request_data.get("sql")) start = timezone.now() @@ -33,6 +39,8 @@ def run_assistant(request_data, user): prompt=prompt, run_by_user=user, run_at=timezone.now(), + user_request=assistant_request, + database_connection=conn ) response_text = None try: @@ -47,7 +55,6 @@ def run_assistant(request_data, user): pl.save() Stat(StatNames.ASSISTANT_RUN, { "included_table_count": len(included_tables), - "extra_table_count": len(extra_tables), "has_sql": bool(sql), "duration": pl.duration, }).track() @@ -67,3 +74,51 @@ def post(self, request, *args, **kwargs): return JsonResponse(response_data) except json.JSONDecodeError: return JsonResponse({"status": "error", "message": "Invalid JSON"}, status=400) + + +class TableDescriptionListView(PermissionRequiredMixin, ExplorerContextMixin, ListView): + model = TableDescription + permission_required = "view_permission" + template_name = "assistant/table_description_list.html" + context_object_name = "table_descriptions" + + +class TableDescriptionCreateView(PermissionRequiredMixin, ExplorerContextMixin, CreateView): + model = TableDescription + permission_required = "change_permission" + template_name = "assistant/table_description_form.html" + success_url 
= reverse_lazy("table_description_list") + form_class = TableDescriptionForm + + +class TableDescriptionUpdateView(PermissionRequiredMixin, ExplorerContextMixin, UpdateView): + model = TableDescription + permission_required = "change_permission" + template_name = "assistant/table_description_form.html" + success_url = reverse_lazy("table_description_list") + form_class = TableDescriptionForm + + +class TableDescriptionDeleteView(PermissionRequiredMixin, ExplorerContextMixin, DeleteView): + model = TableDescription + permission_required = "change_permission" + template_name = "assistant/table_description_confirm_delete.html" + success_url = reverse_lazy("table_description_list") + + +class AssistantHistoryApiView(View): + + def post(self, request, *args, **kwargs): + try: + data = json.loads(request.body) + logs = PromptLog.objects.filter( + run_by_user=request.user, + database_connection_id=data["connection_id"] + ).order_by("-run_at")[:5] + ret = [{ + "user_request": log.user_request, + "response": log.response + } for log in logs] + return JsonResponse({"logs": ret}) + except json.JSONDecodeError: + return JsonResponse({"status": "error", "message": "Invalid JSON"}, status=400) diff --git a/explorer/ee/db_connections/models.py b/explorer/ee/db_connections/models.py index ea81a8f8..140543ac 100644 --- a/explorer/ee/db_connections/models.py +++ b/explorer/ee/db_connections/models.py @@ -4,7 +4,7 @@ from django.db.utils import load_backend from explorer.app_settings import EXPLORER_CONNECTIONS from explorer.ee.db_connections.utils import quick_hash, uploaded_db_local_path - +from django.core.cache import cache from django_cryptography.fields import encrypt @@ -66,12 +66,27 @@ def _download_sqlite(self): s3 = get_s3_bucket() s3.download_file(self.host, self.local_name) + def _download_needed(self): + # If the file doesn't exist, obviously we need to download it + # If it does exist, then check if it's out of date. 
But only check if in fact the DatabaseConnection has been + # saved to the DB. For example, we might be validating an unsaved connection, in which case the fingerprint + # won't be set yet. + return (not os.path.exists(self.local_name) or + (self.id is not None and self.local_fingerprint() != self.upload_fingerprint)) + def download_sqlite_if_needed(self): - download = not os.path.exists(self.local_name) or self.local_fingerprint() != self.upload_fingerprint - if download: - self._download_sqlite() - self.update_fingerprint() + if self._download_needed(): + cache_key = f"download_lock_{self.local_name}" + lock_acquired = cache.add(cache_key, "locked", timeout=300) # Timeout after 5 minutes + + if lock_acquired: + try: + if self._download_needed(): + self._download_sqlite() + self.update_fingerprint() + finally: + cache.delete(cache_key) @property def is_upload(self): diff --git a/explorer/ee/urls.py b/explorer/ee/urls.py index 0359d628..e1bfdc92 100644 --- a/explorer/ee/urls.py +++ b/explorer/ee/urls.py @@ -20,12 +20,7 @@ path("connections/create_upload/", DatabaseConnectionUploadCreateView.as_view(), name="explorer_upload_create"), path("connections//edit/", DatabaseConnectionUpdateView.as_view(), name="explorer_connection_update"), path("connections//delete/", DatabaseConnectionDeleteView.as_view(), name="explorer_connection_delete"), - # There are two URLs here because the form can call validate from /connections/new/ or from /connections//edit/ - # which have different relative paths. It's easier to just provide both of these URLs rather than deal with this - # client-side. 
path("connections/validate/", DatabaseConnectionValidateView.as_view(), name="explorer_connection_validate"), - path("connections//validate/", DatabaseConnectionValidateView.as_view(), - name="explorer_connection_validate_with_pk"), path("connections//refresh/", DatabaseConnectionRefreshView.as_view(), name="explorer_connection_refresh") ] diff --git a/explorer/forms.py b/explorer/forms.py index fb6511c2..e025b558 100644 --- a/explorer/forms.py +++ b/explorer/forms.py @@ -34,6 +34,7 @@ class QueryForm(ModelForm): sql = SqlField() snapshot = BooleanField(widget=CheckboxInput, required=False) + few_shot = BooleanField(widget=CheckboxInput, required=False) database_connection = CharField(widget=Select, required=False) def __init__(self, *args, **kwargs): @@ -77,4 +78,4 @@ def connections(self): class Meta: model = Query - fields = ["title", "sql", "description", "snapshot", "database_connection"] + fields = ["title", "sql", "description", "snapshot", "database_connection", "few_shot"] diff --git a/explorer/migrations/0026_tabledescription.py b/explorer/migrations/0026_tabledescription.py new file mode 100644 index 00000000..782ed1bc --- /dev/null +++ b/explorer/migrations/0026_tabledescription.py @@ -0,0 +1,26 @@ +# Generated by Django 5.0.4 on 2024-08-22 01:24 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('explorer', '0025_alter_query_database_connection_alter_querylog_database_connection'), + ] + + operations = [ + migrations.CreateModel( + name='TableDescription', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('table_name', models.CharField(max_length=512)), + ('description', models.TextField()), + ('database_connection', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='explorer.databaseconnection')), + ], + options={ + 'unique_together': {('database_connection', 'table_name')}, + 
}, + ), + ] diff --git a/explorer/migrations/0027_query_few_shot.py b/explorer/migrations/0027_query_few_shot.py new file mode 100644 index 00000000..18e901fb --- /dev/null +++ b/explorer/migrations/0027_query_few_shot.py @@ -0,0 +1,18 @@ +# Generated by Django 5.0.4 on 2024-08-25 21:26 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('explorer', '0026_tabledescription'), + ] + + operations = [ + migrations.AddField( + model_name='query', + name='few_shot', + field=models.BooleanField(default=False, help_text='Will be included as a good example of SQL in assistant queries that use relevant tables'), + ), + ] diff --git a/explorer/migrations/0028_promptlog_database_connection_promptlog_user_request.py b/explorer/migrations/0028_promptlog_database_connection_promptlog_user_request.py new file mode 100644 index 00000000..dfba7809 --- /dev/null +++ b/explorer/migrations/0028_promptlog_database_connection_promptlog_user_request.py @@ -0,0 +1,24 @@ +# Generated by Django 5.0.4 on 2024-08-27 18:59 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('explorer', '0027_query_few_shot'), + ] + + operations = [ + migrations.AddField( + model_name='promptlog', + name='database_connection', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='explorer.databaseconnection'), + ), + migrations.AddField( + model_name='promptlog', + name='user_request', + field=models.TextField(blank=True), + ), + ] diff --git a/explorer/models.py b/explorer/models.py index 9fe8ef11..67939df8 100644 --- a/explorer/models.py +++ b/explorer/models.py @@ -9,8 +9,6 @@ from django.utils.translation import gettext_lazy as _ from explorer import app_settings -# import the models so that the migration tooling knows the assistant models are part of the explorer app -from explorer.assistant import models as 
assistant_models # noqa from explorer.telemetry import Stat, StatNames from explorer.utils import ( extract_params, get_params_for_url, get_s3_bucket, passes_blacklist, s3_url, @@ -21,7 +19,7 @@ # Issue #618. All models must be imported so that Django understands how to manage migrations for the app from explorer.ee.db_connections.models import DatabaseConnection # noqa -from explorer.assistant.models import PromptLog # noqa +from explorer.assistant.models import PromptLog, TableDescription # noqa MSG_FAILED_BLACKLIST = "Query failed the SQL blacklist: %s" @@ -58,6 +56,8 @@ class Query(models.Model): ) ) database_connection = models.ForeignKey(to=DatabaseConnection, on_delete=models.SET_NULL, null=True) + few_shot = models.BooleanField(default=False, help_text=_( + "Will be included as a good example of SQL in assistant queries that use relevant tables")) def __init__(self, *args, **kwargs): self.params = kwargs.get("params") diff --git a/explorer/src/js/assistant.js b/explorer/src/js/assistant.js index 9e7c4690..8e7ce159 100644 --- a/explorer/src/js/assistant.js +++ b/explorer/src/js/assistant.js @@ -1,69 +1,99 @@ import {getCsrfToken} from "./csrf"; import { marked } from "marked"; import DOMPurify from "dompurify"; -import * as bootstrap from 'bootstrap'; -import List from "list.js"; +import * as bootstrap from "bootstrap"; import { SchemaSvc, getConnElement } from "./schemaService" +import Choices from "choices.js" function getErrorMessage() { const errorElement = document.querySelector('.alert-danger.db-error'); return errorElement ? 
errorElement.textContent.trim() : null; } +function debounce(func, delay) { + let timeout; + return function(...args) { + clearTimeout(timeout); + timeout = setTimeout(() => func.apply(this, args), delay); + }; +} + function setupTableList() { + + if(window.assistantChoices) { + window.assistantChoices.destroy(); + } + SchemaSvc.get().then(schema => { const keys = Object.keys(schema); - const tableList = document.getElementById('table-list'); - tableList.innerHTML = ''; + const selectElement = document.createElement('select'); + selectElement.className = 'js-choice'; + selectElement.toggleAttribute('multiple'); + selectElement.toggleAttribute('data-trigger'); - keys.forEach((key, index) => { - const div = document.createElement('div'); - div.className = 'form-check'; - - const input = document.createElement('input'); - input.className = 'form-check-input table-checkbox'; - input.type = 'checkbox'; - input.value = key; - input.id = 'flexCheckDefault' + index; + keys.forEach((key) => { + const option = document.createElement('option'); + option.value = key; + option.textContent = key; + selectElement.appendChild(option); + }); - const label = document.createElement('label'); - label.className = 'form-check-label'; - label.setAttribute('for', input.id); - label.textContent = key; + const tableList = document.getElementById('table-list'); + tableList.innerHTML = ''; + tableList.appendChild(selectElement); - div.appendChild(input); - div.appendChild(label); - tableList.appendChild(div); + const choices = new Choices('.js-choice', { + removeItemButton: true, + searchEnabled: true, + shouldSort: false, + position: 'bottom' }); - let options = { - valueNames: ['form-check-label'], - }; - - new List('additional_table_container', options); + // TODO - nasty. Should be refactored. Used by submitAssistantAsk to get relevant tables. 
+ window.assistantChoices = choices; const selectAllButton = document.getElementById('select_all_button'); - const checkboxes = document.querySelectorAll('.table-checkbox'); - - let selectState = 'all'; - - selectAllButton.innerHTML = 'Select All'; - selectAllButton.addEventListener('click', (e) => { e.preventDefault(); - const isSelectingAll = selectState === 'all'; - checkboxes.forEach((checkbox) => { - checkbox.checked = isSelectingAll; + choices.setChoiceByValue(keys); + }); + + const deselectAllButton = document.getElementById('deselect_all_button'); + deselectAllButton.addEventListener('click', (e) => { + e.preventDefault(); + keys.forEach(k => { + choices.removeActiveItemsByValue(k); }); - selectState = isSelectingAll ? 'none' : 'all'; - selectAllButton.innerHTML = isSelectingAll ? 'Deselect All' : 'Select All'; }); + + selectRelevantTablesSql(choices, keys); + + document.addEventListener('docChanged', debounce( + () => selectRelevantTablesSql(choices, keys), 500)); + + document.getElementById('id_assistant_input').addEventListener('input', debounce( + () => selectRelevantTablesRequest(choices, keys), 300)); + }) .catch(error => { console.error('Error retrieving JSON schema:', error); }); } +function selectRelevantTablesSql(choices, keys) { + const textContent = window.editor.state.doc.toString(); + const textWords = new Set(textContent.split(/\s+/)); + const hasKeys = keys.filter(key => textWords.has(key)); + choices.setChoiceByValue(hasKeys); +} + +function selectRelevantTablesRequest(choices, keys) { + const textContent = document.getElementById("id_assistant_input").value + const textWords = new Set(textContent.split(/\s+/)); + const hasKeys = keys.filter(key => textWords.has(key)); + choices.setChoiceByValue(hasKeys); +} + export function setUpAssistant(expand = false) { getConnElement().addEventListener('change', setupTableList); @@ -71,13 +101,13 @@ export function setUpAssistant(expand = false) { const error = getErrorMessage(); - if(expand || error) 
{ + if (expand || error) { const myCollapseElement = document.getElementById('assistant_collapse'); const bsCollapse = new bootstrap.Collapse(myCollapseElement, { - toggle: false + toggle: false }); bsCollapse.show(); - if(error) { + if (error) { document.getElementById('id_error_help_message').classList.remove('d-none'); } } @@ -85,7 +115,7 @@ export function setUpAssistant(expand = false) { const tooltipTriggerList = document.querySelectorAll('[data-bs-toggle="tooltip"]'); [...tooltipTriggerList].map(tooltipTriggerEl => new bootstrap.Tooltip(tooltipTriggerEl)); - document.getElementById('id_assistant_input').addEventListener('keydown', function(event) { + document.getElementById('id_assistant_input').addEventListener('keydown', function (event) { if ((event.ctrlKey || event.metaKey) && (event.key === 'Enter')) { event.preventDefault(); submitAssistantAsk(); @@ -93,19 +123,105 @@ export function setUpAssistant(expand = false) { }); document.getElementById('ask_assistant_btn').addEventListener('click', submitAssistantAsk); + + document.getElementById('assistant_history').addEventListener('click', getAssistantHistory); + } -function submitAssistantAsk() { +function getAssistantHistory() { + + const historyModalId = 'historyModal'; - const selectedTables = Array.from( - document.querySelectorAll('.table-checkbox:checked') - ).map(cb => cb.value); + // Remove any existing modal with the same ID + const existingModal = document.getElementById(historyModalId); + if (existingModal) { + existingModal.remove() + } + + const data = { + connection_id: document.getElementById("id_database_connection")?.value ?? 
null + }; + + fetch(`${window.baseUrlPath}assistant/history/`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-CSRFToken': getCsrfToken() + }, + body: JSON.stringify(data) + }) + .then(response => { + if (!response.ok) { + throw new Error('Network response was not ok'); + } + return response.json(); + }) + .then(data => { + // Create table rows from the fetched data + let tableRows = ''; + data.logs.forEach(log => { + let md = DOMPurify.sanitize(marked.parse(log.response)); + tableRows += ` + + ${log.user_request} + ${md} + + `; + }); + + // Create the complete table HTML + const tableHtml = ` + + + + + + + + + ${tableRows} + +
User RequestResponse
+ `; + + // Insert the table into a new Bootstrap modal + const modalHtml = ` + + `; + + document.body.insertAdjacentHTML('beforeend', modalHtml); + const historyModal = new bootstrap.Modal(document.getElementById(historyModalId)); + historyModal.show(); + + }) + .catch(error => { + console.error('There was a problem with the fetch operation:', error); + }); +} + + +function submitAssistantAsk() { const data = { sql: window.editor?.state.doc.toString() ?? null, connection_id: document.getElementById("id_database_connection")?.value ?? null, assistant_request: document.getElementById("id_assistant_input")?.value ?? null, - selected_tables: selectedTables, + selected_tables: assistantChoices.getValue(true), db_error: getErrorMessage() }; @@ -113,7 +229,7 @@ function submitAssistantAsk() { document.getElementById("response_block").classList.remove('d-none'); document.getElementById("assistant_spinner").classList.remove('d-none'); - fetch('../assistant/', { + fetch(`${window.baseUrlPath}assistant/`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/explorer/src/js/codemirror-config.js b/explorer/src/js/codemirror-config.js index a4539e49..3944c522 100644 --- a/explorer/src/js/codemirror-config.js +++ b/explorer/src/js/codemirror-config.js @@ -14,9 +14,10 @@ import { Prec } from "@codemirror/state"; import {sql} from "@codemirror/lang-sql"; import { SchemaSvc } from "./schemaService" + let updateListenerExtension = EditorView.updateListener.of((update) => { if (update.docChanged) { - document.dispatchEvent(new CustomEvent('docChanged', {})); + document.dispatchEvent(new CustomEvent('docChanged', {})); } }); @@ -34,12 +35,21 @@ const hideTooltipOnEsc = EditorView.domEventHandlers({ } }); -function displaySchemaTooltip(editor, content) { +function displaySchemaTooltip(content) { let tooltip = document.getElementById('schema_tooltip'); if (tooltip) { tooltip.classList.remove('d-none'); tooltip.classList.add('d-block'); - tooltip.textContent = 
content; + + // Clear existing content + tooltip.textContent = ''; + + content.forEach(item => { + let column = document.createElement('span'); + column.textContent = item; + column.classList.add('mx-1') + tooltip.appendChild(column); + }); } } @@ -53,11 +63,11 @@ function fetchAndShowSchema(view) { SchemaSvc.get().then(schema => { let formattedSchema; if (schema.hasOwnProperty(tableName)) { - formattedSchema = JSON.stringify(schema[tableName], null, 2); + displaySchemaTooltip(schema[tableName]); } else { - formattedSchema = `Table '${tableName}' not found in schema for connection`; + const errorMsg = [`Table '${tableName}' not found in schema for connection`]; + displaySchemaTooltip(errorMsg); } - displaySchemaTooltip(view, formattedSchema); }); } return true; @@ -99,6 +109,27 @@ const submitKeymapArr = [ } ] + +const formatEventFromCM = new CustomEvent('formatEventFromCM', {}); +const formatKeymap = [ + { + key: "Ctrl-F", + mac: "Cmd-F", + run: () => { + document.dispatchEvent(formatEventFromCM); + return true; + } + }, + { + key: "Ctrl-F", + mac: "Cmd-F", + run: () => { + document.dispatchEvent(formatEventFromCM); + return true; + } + } +] + const submitKeymap = Prec.highest( keymap.of( submitKeymapArr @@ -136,6 +167,7 @@ export const explorerSetup = (() => [ ...completionKeymap, ...lintKeymap, ...autocompleteKeymap, - ...schemaKeymap + ...schemaKeymap, + ...formatKeymap ]) ])() diff --git a/explorer/src/js/explorer.js b/explorer/src/js/explorer.js index ad113de1..823d25fd 100644 --- a/explorer/src/js/explorer.js +++ b/explorer/src/js/explorer.js @@ -25,7 +25,7 @@ function updateSchema() { }); }); - $("#schema_frame").attr("src", `../schema/${getConnElement().value}`); + $("#schema_frame").attr("src", `${window.baseUrlPath}schema/${getConnElement().value}`); } @@ -89,6 +89,10 @@ export class ExplorerEditor { this.$submit.click(); }); + document.addEventListener('formatEventFromCM', (e) => { + this.formatSql(); + }); + document.addEventListener('docChanged', (e) 
=> { this.docChanged = true; }); @@ -157,7 +161,7 @@ export class ExplorerEditor { formData.append('sql', sqlText); // Append the SQL text to the form data // Make the fetch call - fetch("../format/", { + fetch(`${window.baseUrlPath}format/`, { method: "POST", headers: { // 'Content-Type': 'application/x-www-form-urlencoded', // Not needed when using FormData, as the browser sets it along with the boundary diff --git a/explorer/src/js/main.js b/explorer/src/js/main.js index 4d698497..8676753a 100644 --- a/explorer/src/js/main.js +++ b/explorer/src/js/main.js @@ -18,11 +18,14 @@ const route_initializers = { explorer_schema: () => import('./schema').then(({setupSchema}) => setupSchema()), explorer_upload_create: () => import('./uploads').then(({setupUploads}) => setupUploads()), explorer_connection_update: () => import('./uploads').then(({setupUploads}) => setupUploads()), - explorer_connection_create: () => import('./uploads').then(({setupUploads}) => setupUploads()) + explorer_connection_create: () => import('./uploads').then(({setupUploads}) => setupUploads()), + table_description_create: () => import('./tableDescription').then(({setupTableDescription}) => setupTableDescription()), + table_description_update: () => import('./tableDescription').then(({setupTableDescription}) => setupTableDescription()), }; document.addEventListener('DOMContentLoaded', function() { const clientRoute = document.getElementById('clientRoute').value; + window.baseUrlPath = document.getElementById('baseUrlPath').value; if (route_initializers.hasOwnProperty(clientRoute)) { route_initializers[clientRoute](); } diff --git a/explorer/src/js/query-list.js b/explorer/src/js/query-list.js index 0646dcb8..2ea8045a 100644 --- a/explorer/src/js/query-list.js +++ b/explorer/src/js/query-list.js @@ -60,7 +60,7 @@ function setUpEmailCsv() { } let handleEmailCsvSubmit = function (e) { let email = document.querySelector('#emailCsvInput').value; - let url = '/' + curQueryEmailId + '/email_csv?email=' + 
email; + let url =`${window.baseUrlPath}${curQueryEmailId}/email_csv?email=${email}`; if (isValidEmail(email)) { fetch(url, { method: 'POST', diff --git a/explorer/src/js/schemaService.js b/explorer/src/js/schemaService.js index 84e9e577..afb4e4f8 100644 --- a/explorer/src/js/schemaService.js +++ b/explorer/src/js/schemaService.js @@ -9,7 +9,7 @@ const fetchSchema = async () => { } try { - const response = await fetch(`../schema.json/${conn}`); + const response = await fetch(`${window.baseUrlPath}schema.json/${conn}`); if (!response.ok) { throw new Error(`HTTP error! Status: ${response.status}`); } diff --git a/explorer/src/js/tableDescription.js b/explorer/src/js/tableDescription.js new file mode 100644 index 00000000..7aadcbe8 --- /dev/null +++ b/explorer/src/js/tableDescription.js @@ -0,0 +1,48 @@ +import {getConnElement, SchemaSvc} from "./schemaService" +import Choices from "choices.js" + + +function populateTableList() { + + if (window.tableChoices) { + window.tableChoices.destroy(); + document.getElementById('id_table_name').innerHTML = ''; + } + + SchemaSvc.get().then(schema => { + + const tables = Object.keys(schema); + const selectElement = document.getElementById('id_table_name'); + selectElement.toggleAttribute('data-trigger'); + + selectElement.appendChild(document.createElement('option')); + + tables.forEach((t) => { + const option = document.createElement('option'); + option.value = t; + option.textContent = t; + selectElement.appendChild(option); + }); + + window.tableChoices = new Choices('#id_table_name', { + searchEnabled: true, + shouldSort: true, + placeholder: true, + placeholderValue: 'Select table', + position: 'bottom' + }); + + }); +} + +function updateSchema() { + document.getElementById("schema_frame").src = `${window.baseUrlPath}schema/${getConnElement().value}`; +} + +export function setupTableDescription() { + getConnElement().addEventListener('change', populateTableList); + populateTableList(); + + 
getConnElement().addEventListener('change', updateSchema); + updateSchema(); +} diff --git a/explorer/src/js/uploads.js b/explorer/src/js/uploads.js index 6b653a04..aa0cd224 100644 --- a/explorer/src/js/uploads.js +++ b/explorer/src/js/uploads.js @@ -52,7 +52,7 @@ export function setupUploads() { } let xhr = new XMLHttpRequest(); - xhr.open('POST', '../upload/', true); + xhr.open('POST', `${window.baseUrlPath}connections/upload/`, true); xhr.setRequestHeader('X-CSRFToken', getCsrfToken()); xhr.upload.onprogress = function(event) { @@ -91,7 +91,7 @@ export function setupUploads() { let form = document.getElementById("db-connection-form"); let formData = new FormData(form); - fetch("../validate/", { + fetch(`${window.baseUrlPath}connections/validate/`, { method: "POST", body: formData, headers: { diff --git a/explorer/src/scss/assistant.scss b/explorer/src/scss/assistant.scss index fbb2c625..4f9fa7ea 100644 --- a/explorer/src/scss/assistant.scss +++ b/explorer/src/scss/assistant.scss @@ -14,13 +14,19 @@ cursor: pointer; } -#additional_table_container { - overflow-y: auto; - max-height: 10rem; -} - #assistant_input_parent { max-height: 120px; overflow: hidden; } +.assistant-icons { + width: 1rem; + position: absolute; + right: .75rem; +} + +#table-list { + .choices__inner { + min-height: 7rem !important; + } +} diff --git a/explorer/src/scss/choices.scss b/explorer/src/scss/choices.scss new file mode 100644 index 00000000..dc9ccaf6 --- /dev/null +++ b/explorer/src/scss/choices.scss @@ -0,0 +1,323 @@ +@import "variables"; + +.choices { + position: relative; + overflow: hidden; + margin-bottom: 24px; +} +.choices:focus { + outline: none; +} +.choices:last-child { + margin-bottom: 0; +} +.choices.is-open { + overflow: visible; +} +.choices.is-disabled .choices__inner, +.choices.is-disabled .choices__input { + background-color: #eaeaea; + cursor: not-allowed; + -webkit-user-select: none; + user-select: none; +} +.choices.is-disabled .choices__item { + cursor: not-allowed; 
+} +.choices [hidden] { + display: none !important; +} + +.choices[data-type*=select-one] { + cursor: pointer; +} +.choices[data-type*=select-one] .choices__inner { + padding-bottom: 7.5px; +} +.choices[data-type*=select-one] .choices__input { + display: block; + width: 100%; + padding: 10px; + border-bottom: 1px solid #ddd; + background-color: #fff; + margin: 0; +} +.choices[data-type*=select-one] .choices__button { + background-image: url("data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjEiIGhlaWdodD0iMjEiIHZpZXdCb3g9IjAgMCAyMSAyMSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48ZyBmaWxsPSIjMDAwIiBmaWxsLXJ1bGU9ImV2ZW5vZGQiPjxwYXRoIGQ9Ik0yLjU5Mi4wNDRsMTguMzY0IDE4LjM2NC0yLjU0OCAyLjU0OEwuMDQ0IDIuNTkyeiIvPjxwYXRoIGQ9Ik0wIDE4LjM2NEwxOC4zNjQgMGwyLjU0OCAyLjU0OEwyLjU0OCAyMC45MTJ6Ii8+PC9nPjwvc3ZnPg=="); + padding: 0; + background-size: 8px; + position: absolute; + top: 50%; + right: 0; + margin-top: -10px; + margin-right: 25px; + height: 20px; + width: 20px; + border-radius: 10em; + opacity: 0.25; +} +.choices[data-type*=select-one] .choices__button:hover, .choices[data-type*=select-one] .choices__button:focus { + opacity: 1; +} +.choices[data-type*=select-one] .choices__button:focus { + box-shadow: 0 0 0 2px #005F75; +} +.choices[data-type*=select-one] .choices__item[data-placeholder] .choices__button { + display: none; +} +.choices[data-type*=select-one]::after { + content: ""; + height: 0; + width: 0; + border-style: solid; + border-color: #333 transparent transparent transparent; + border-width: 5px; + position: absolute; + right: 11.5px; + top: 50%; + margin-top: -2.5px; + pointer-events: none; +} +.choices[data-type*=select-one].is-open::after { + border-color: transparent transparent #333; + margin-top: -7.5px; +} +.choices[data-type*=select-one][dir=rtl]::after { + left: 11.5px; + right: auto; +} +.choices[data-type*=select-one][dir=rtl] .choices__button { + right: auto; + left: 0; + margin-left: 25px; + margin-right: 0; +} + +.choices[data-type*=select-multiple] 
.choices__inner, +.choices[data-type*=text] .choices__inner { + cursor: text; +} +.choices[data-type*=select-multiple] .choices__button, +.choices[data-type*=text] .choices__button { + position: relative; + display: inline-block; + margin-top: 0; + margin-right: -4px; + margin-bottom: 0; + margin-left: 8px; + padding-left: 16px; + border-left: 1px solid #003642; + background-image: url("data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMjEiIGhlaWdodD0iMjEiIHZpZXdCb3g9IjAgMCAyMSAyMSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj48ZyBmaWxsPSIjRkZGIiBmaWxsLXJ1bGU9ImV2ZW5vZGQiPjxwYXRoIGQ9Ik0yLjU5Mi4wNDRsMTguMzY0IDE4LjM2NC0yLjU0OCAyLjU0OEwuMDQ0IDIuNTkyeiIvPjxwYXRoIGQ9Ik0wIDE4LjM2NEwxOC4zNjQgMGwyLjU0OCAyLjU0OEwyLjU0OCAyMC45MTJ6Ii8+PC9nPjwvc3ZnPg=="); + background-size: 8px; + width: 8px; + line-height: 1; + opacity: 0.75; + border-radius: 0; +} +.choices[data-type*=select-multiple] .choices__button:hover, .choices[data-type*=select-multiple] .choices__button:focus, +.choices[data-type*=text] .choices__button:hover, +.choices[data-type*=text] .choices__button:focus { + opacity: 1; +} + +.choices__inner { + display: inline-block; + vertical-align: top; + width: 100%; + background-color: #f9f9f9; + padding: 7.5px 7.5px 3.75px; + border: 1px solid #ddd; + border-radius: 6px; + overflow: hidden; +} +.is-focused .choices__inner, .is-open .choices__inner { + border-color: #b7b7b7; +} +.is-open .choices__inner { + border-radius: 2.5px 2.5px 0 0; +} +.is-flipped.is-open .choices__inner { + border-radius: 0 0 2.5px 2.5px; +} + +.choices__list { + margin: 0; + padding-left: 0; + list-style: none; +} +.choices__list--single { + display: inline-block; + padding: 4px 16px 4px 4px; + width: 100%; +} +[dir=rtl] .choices__list--single { + padding-right: 4px; + padding-left: 16px; +} +.choices__list--single .choices__item { + width: 100%; +} + +.choices__list--multiple { + display: inline; +} +.choices__list--multiple .choices__item { + display: inline-block; + vertical-align: middle; + 
border-radius: 20px; + padding: 4px 10px; + font-weight: 500; + margin-right: 3.75px; + margin-bottom: 3.75px; + background-color: $dark-lightened; + border: 1px solid $dark; + color: #fff; + word-break: break-all; + box-sizing: border-box; +} +.choices__list--multiple .choices__item[data-deletable] { + padding-right: 5px; +} +[dir=rtl] .choices__list--multiple .choices__item { + margin-right: 0; + margin-left: 3.75px; +} +.is-disabled .choices__list--multiple .choices__item { + background-color: #aaaaaa; + border: 1px solid #919191; +} + +.choices__list--dropdown, .choices__list[aria-expanded] { + display: none; + z-index: 1; + position: absolute; + width: 100%; + background-color: #fff; + border: 1px solid #ddd; + top: 100%; + margin-top: -1px; + border-bottom-left-radius: 2.5px; + border-bottom-right-radius: 2.5px; + overflow: hidden; + word-break: break-all; +} +.is-active.choices__list--dropdown, .is-active.choices__list[aria-expanded] { + display: block; +} +.is-open .choices__list--dropdown, .is-open .choices__list[aria-expanded] { + border-color: #b7b7b7; +} +.is-flipped .choices__list--dropdown, .is-flipped .choices__list[aria-expanded] { + top: auto; + bottom: 100%; + margin-top: 0; + margin-bottom: -1px; + border-radius: 0.25rem 0.25rem 0 0; +} +.choices__list--dropdown .choices__list, .choices__list[aria-expanded] .choices__list { + position: relative; + max-height: 300px; + overflow: auto; + -webkit-overflow-scrolling: touch; + will-change: scroll-position; +} +.choices__list--dropdown .choices__item, .choices__list[aria-expanded] .choices__item { + position: relative; + padding: 10px; +} +[dir=rtl] .choices__list--dropdown .choices__item, [dir=rtl] .choices__list[aria-expanded] .choices__item { + text-align: right; +} +@media (min-width: 640px) { + .choices__list--dropdown .choices__item--selectable[data-select-text], .choices__list[aria-expanded] .choices__item--selectable[data-select-text] { + padding-right: 100px; + } + .choices__list--dropdown 
.choices__item--selectable[data-select-text]::after, .choices__list[aria-expanded] .choices__item--selectable[data-select-text]::after { + content: attr(data-select-text); + opacity: 0; + position: absolute; + right: 10px; + top: 50%; + transform: translateY(-50%); + } + [dir=rtl] .choices__list--dropdown .choices__item--selectable[data-select-text], [dir=rtl] .choices__list[aria-expanded] .choices__item--selectable[data-select-text] { + text-align: right; + padding-left: 100px; + padding-right: 10px; + } + [dir=rtl] .choices__list--dropdown .choices__item--selectable[data-select-text]::after, [dir=rtl] .choices__list[aria-expanded] .choices__item--selectable[data-select-text]::after { + right: auto; + left: 10px; + } +} +.choices__list--dropdown .choices__item--selectable.is-highlighted, .choices__list[aria-expanded] .choices__item--selectable.is-highlighted { + background-color: #f2f2f2; +} +.choices__list--dropdown .choices__item--selectable.is-highlighted::after, .choices__list[aria-expanded] .choices__item--selectable.is-highlighted::after { + opacity: 0.5; +} + +.choices__item { + cursor: default; +} + +.choices__item--selectable { + cursor: pointer; +} + +.choices__item--disabled { + cursor: not-allowed; + -webkit-user-select: none; + user-select: none; + opacity: 0.5; +} + +.choices__heading { + font-weight: 600; + padding: 10px; + border-bottom: 1px solid #f7f7f7; + color: gray; +} + +.choices__button { + text-indent: -9999px; + appearance: none; + border: 0; + background-color: transparent; + background-repeat: no-repeat; + background-position: center; + cursor: pointer; +} +.choices__button:focus { + outline: none; +} + +.choices__input { + display: inline-block; + vertical-align: baseline; + background-color: #f9f9f9; + margin-bottom: 5px; + border: 0; + border-radius: 0; + max-width: 100%; + padding: 4px 0 4px 2px; +} +.choices__input:focus { + outline: 0; +} +.choices__input::-webkit-search-decoration, .choices__input::-webkit-search-cancel-button, 
.choices__input::-webkit-search-results-button, .choices__input::-webkit-search-results-decoration { + display: none; +} +.choices__input::-ms-clear, .choices__input::-ms-reveal { + display: none; + width: 0; + height: 0; +} +[dir=rtl] .choices__input { + padding-right: 2px; + padding-left: 0; +} + +.choices__placeholder { + opacity: 0.5; +} diff --git a/explorer/src/scss/styles.scss b/explorer/src/scss/styles.scss index cece367d..39515c17 100644 --- a/explorer/src/scss/styles.scss +++ b/explorer/src/scss/styles.scss @@ -5,8 +5,8 @@ $bootstrap-icons-font-dir: "../../../node_modules/bootstrap-icons/font/fonts"; @import "~bootstrap-icons/font/bootstrap-icons.css"; - @import "explorer"; @import "assistant"; +@import "choices"; @import "pivot.css"; diff --git a/explorer/src/scss/variables.scss b/explorer/src/scss/variables.scss index 6397f6dc..e2cae45e 100644 --- a/explorer/src/scss/variables.scss +++ b/explorer/src/scss/variables.scss @@ -5,6 +5,7 @@ $blue: rgb(3, 68, 220); $green: rgb(127, 176, 105); $primary: $blue; $dark: rgb(1, 32, 63); +$dark-lightened: rgba(1, 32, 63, 0.75); $secondary: $orange; $warning: $orange; $danger: $orange; diff --git a/explorer/templates/assistant/table_description_confirm_delete.html b/explorer/templates/assistant/table_description_confirm_delete.html new file mode 100644 index 00000000..87080208 --- /dev/null +++ b/explorer/templates/assistant/table_description_confirm_delete.html @@ -0,0 +1,14 @@ +{% extends "explorer/base.html" %} + +{% block sql_explorer_content %} +
+

Confirm Delete

+

Are you sure you want to delete the table description for "{{ object.table_name }}" in {{ object.connection.alias }}?

+
+ {% csrf_token %} + + Cancel +
+
+{% endblock %} + diff --git a/explorer/templates/assistant/table_description_form.html b/explorer/templates/assistant/table_description_form.html new file mode 100644 index 00000000..c3bacaad --- /dev/null +++ b/explorer/templates/assistant/table_description_form.html @@ -0,0 +1,50 @@ +{% extends "explorer/base.html" %} + +{% block sql_explorer_content %} +
+
+
+

{% if form.instance.pk %}Edit{% else %}Create{% endif %} Table Description

+ {% if form.errors %} + {% for field in form %} + {% for error in field.errors %} +
+ {{ error|escape }} +
+ {% endfor %} + {% endfor %} + {% for error in form.non_field_errors %} +
+ {{ error|escape }} +
+ {% endfor %} + {% endif %} +
+ {% csrf_token %} +
+
+ {{ form.database_connection }} + +
+
+ Table Name + {{ form.table_name }} +
+
+ {{ form.description }} + +
+
+ + Cancel +
+
+
+
+ +
+
+
+
+ +{% endblock %} diff --git a/explorer/templates/assistant/table_description_list.html b/explorer/templates/assistant/table_description_list.html new file mode 100644 index 00000000..1ae59f3d --- /dev/null +++ b/explorer/templates/assistant/table_description_list.html @@ -0,0 +1,38 @@ +{% extends "explorer/base.html" %} + +{% block sql_explorer_content %} +
+

Table Annotations

+

These will be injected into any AI assistant prompts that reference the annotated table.

+
+ Create New + + + + + + + + + + + {% for table_description in table_descriptions %} + + + + + + + {% empty %} + + + + {% endfor %} + +
ConnectionTable NameDescriptionActions
{{ table_description.database_connection }}{{ table_description.table_name }}{{ table_description.description|truncatewords:20 }} + + +
No table descriptions available.
+
+
+{% endblock %} diff --git a/explorer/templates/explorer/assistant.html b/explorer/templates/explorer/assistant.html index beb4deba..afec2aa0 100644 --- a/explorer/templates/explorer/assistant.html +++ b/explorer/templates/explorer/assistant.html @@ -9,44 +9,52 @@
+
+
+
+

+ Loading... +

+
+
-
+
"Ask Assistant" to try and automatically fix the issue. The assistant is already aware of error messages & context.
-
- - - (?) - - -
+
+
+
+
+
+
+ +
+
+ +
+
+ +
+
+
-
-
+
-
-
-
-

- Loading... -

-
-
diff --git a/explorer/templates/explorer/base.html b/explorer/templates/explorer/base.html index 5eb41780..32bac5ba 100644 --- a/explorer/templates/explorer/base.html +++ b/explorer/templates/explorer/base.html @@ -19,6 +19,7 @@ + {% if vite_dev_mode %}

@@ -63,6 +64,12 @@

This is easy to fix, I promise!

{% translate "Connections" %} + {% if assistant_enabled %} + + {% endif %} {% endif %} {% endif %}
{% if query and can_change and tasks_enabled %}{{ form.snapshot }} {% translate "Snapshot" %}{% endif %} + {% if query and can_change and assistant_enabled %}{{ form.few_shot }} {% translate "Assistant Example" %}{% endif %}
{% endblock %} diff --git a/explorer/tests/test_assistant.py b/explorer/tests/test_assistant.py index e9a9484a..d1bcde27 100644 --- a/explorer/tests/test_assistant.py +++ b/explorer/tests/test_assistant.py @@ -5,11 +5,22 @@ import json from django.test import TestCase +from django.utils import timezone from django.urls import reverse from django.contrib.auth.models import User from django.db import OperationalError from explorer.ee.db_connections.utils import default_db_connection -from explorer.assistant.utils import sample_rows_from_table, ROW_SAMPLE_SIZE, build_prompt +from explorer.ee.db_connections.models import DatabaseConnection +from explorer.assistant.utils import ( + sample_rows_from_table, + ROW_SAMPLE_SIZE, + build_prompt, + get_relevant_few_shots, + get_relevant_annotation +) + +from explorer.assistant.models import TableDescription +from explorer.models import PromptLog def conn(): @@ -31,23 +42,19 @@ def setUp(self): } @patch("explorer.assistant.utils.openai_client") - @patch("explorer.assistant.utils.num_tokens_from_string") - def test_do_modify_query(self, mocked_num_tokens, mocked_openai_client): + def test_do_modify_query(self, mocked_openai_client): from explorer.assistant.views import run_assistant # create.return_value should match: resp.choices[0].message mocked_openai_client.return_value.chat.completions.create.return_value = Mock( choices=[Mock(message=Mock(content="smart computer"))]) - mocked_num_tokens.return_value = 100 resp = run_assistant(self.request_data, None) self.assertEqual(resp, "smart computer") @patch("explorer.assistant.utils.openai_client") - @patch("explorer.assistant.utils.num_tokens_from_string") - def test_assistant_help(self, mocked_num_tokens, mocked_openai_client): + def test_assistant_help(self, mocked_openai_client): mocked_openai_client.return_value.chat.completions.create.return_value = Mock( choices=[Mock(message=Mock(content="smart computer"))]) - mocked_num_tokens.return_value = 100 resp = 
self.client.post(reverse("assistant"), data=json.dumps(self.request_data), content_type="application/json") @@ -57,38 +64,45 @@ def test_assistant_help(self, mocked_num_tokens, mocked_openai_client): @unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled") class TestBuildPrompt(TestCase): - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_vendor_only(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_vendor_only(self, mock_get_item): mock_get_item.return_value.value = "system prompt" + result = build_prompt(default_db_connection(), + "Help me with SQL", [], sql="SELECT * FROM table;") + self.assertIn("sqlite", result["system"]) - included_tables = [] + @patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data") + @patch("explorer.assistant.utils.table_schema", return_value=[]) + @patch("explorer.models.ExplorerValue.objects.get_item") + def test_build_prompt_with_sql_and_annotation(self, mock_get_item, mock_table_schema, mock_sample_rows): + mock_get_item.return_value.value = "system prompt" - result = build_prompt(default_db_connection(), "Help me with SQL", included_tables) - self.assertIn("## Database Vendor / SQL Flavor is sqlite", result["user"]) - self.assertIn("## User's Request to Assistant ##\n\nHelp me with SQL\n\n", result["user"]) - self.assertEqual(result["system"], "system prompt") + included_tables = ["foo"] + td = TableDescription(database_connection=default_db_connection(), table_name="foo", description="annotated") + td.save() - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) + result = build_prompt(default_db_connection(), + "Help me with SQL", included_tables, 
sql="SELECT * FROM table;") + self.assertIn("Usage Notes:\nannotated", result["user"]) + + @patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data") + @patch("explorer.assistant.utils.table_schema", return_value=[]) @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_sql(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_few_shot(self, mock_get_item, mock_table_schema, mock_sample_rows): mock_get_item.return_value.value = "system prompt" - included_tables = [] + included_tables = ["magic"] + SimpleQueryFactory(title="Few shot", description="the quick brown fox", sql="select 'magic value';", + few_shot=True) result = build_prompt(default_db_connection(), "Help me with SQL", included_tables, sql="SELECT * FROM table;") - self.assertIn("## Database Vendor / SQL Flavor is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\n\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\n\nHelp me with SQL\n\n", result["user"]) - self.assertEqual(result["system"], "system prompt") + self.assertIn("Relevant example queries", result["user"]) + self.assertIn("magic value", result["user"]) - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) + @patch("explorer.assistant.utils.sample_rows_from_table", return_value="sample data") @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_sample_rows): mock_get_item.return_value.value = "system prompt" included_tables = [] @@ -96,45 +110,22 @@ def test_build_prompt_with_sql_and_error(self, mock_get_item, mock_fits_in_windo result = build_prompt(default_db_connection(), "Help me with SQL", included_tables, "Syntax 
error", "SELECT * FROM table;") - self.assertIn("## Database Vendor / SQL Flavor is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\n\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## Query Error ##\n\nSyntax error\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\n\nHelp me with SQL\n\n", result["user"]) - self.assertEqual(result["system"], "system prompt") - - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=True) - @patch("explorer.models.ExplorerValue.objects.get_item") - def test_build_prompt_with_extra_tables_fitting_window(self, mock_get_item, mock_fits_in_window, mock_sample_rows): - mock_get_item.return_value.value = "system prompt" + self.assertIn("## Existing User-Written SQL ##\nSELECT * FROM table;", result["user"]) + self.assertIn("## Query Error ##\nSyntax error\n", result["user"]) + self.assertIn("## User's Request to Assistant ##\nHelp me with SQL", result["user"]) + self.assertIn("system prompt", result["system"]) - included_tables = ["table1", "table2"] - - result = build_prompt(default_db_connection(), "Help me with SQL", - included_tables, sql="SELECT * FROM table;") - self.assertIn("## Database Vendor / SQL Flavor is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\n\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## Table Structure with Sampled Data ##\n\nsample data\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\n\nHelp me with SQL\n\n", result["user"]) - self.assertEqual(result["system"], "system prompt") - - @patch("explorer.assistant.utils.sample_rows_from_tables", return_value="sample data") - @patch("explorer.assistant.utils.fits_in_window", return_value=False) - @patch("explorer.assistant.utils.tables_from_schema_info", return_value="table structure") @patch("explorer.models.ExplorerValue.objects.get_item") - def 
test_build_prompt_with_extra_tables_not_fitting_window(self, mock_get_item, mock_tables_from_schema_info, - mock_fits_in_window, mock_sample_rows): + def test_build_prompt_with_extra_tables_fitting_window(self, mock_get_item): mock_get_item.return_value.value = "system prompt" - included_tables = ["table1", "table2"] + included_tables = ["explorer_query"] + SimpleQueryFactory() result = build_prompt(default_db_connection(), "Help me with SQL", included_tables, sql="SELECT * FROM table;") - self.assertIn("## Database Vendor / SQL Flavor is sqlite", result["user"]) - self.assertIn("## Existing SQL ##\n\nSELECT * FROM table;\n\n", result["user"]) - self.assertIn("## Table Structure ##\n\ntable structure\n\n", result["user"]) - self.assertIn("## User's Request to Assistant ##\n\nHelp me with SQL\n\n", result["user"]) - self.assertEqual(result["system"], "system prompt") + self.assertIn("## Information for Table 'explorer_query' ##", result["user"]) + self.assertIn("Sample rows:\nid | title", result["user"]) @unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled") @@ -161,10 +152,10 @@ def test_truncates_long_strings(self): header, row = ret self.assertEqual(header, ["col1", "col2"]) - self.assertEqual(row[0], "a" * 500 + "...") + self.assertEqual(row[0], "a" * 200 + "...") self.assertEqual(row[1], "short string") - def test_truncates_long_binary_data(self): + def test_binary_data(self): long_binary = b"a" * 600 # Mock database connection and cursor @@ -179,8 +170,8 @@ def test_truncates_long_binary_data(self): header, row = ret self.assertEqual(header, ["col1", "col2"]) - self.assertEqual(row[0], b"a" * 500 + b"...") - self.assertEqual(row[1], b"short binary") + self.assertEqual(row[0], "") + self.assertEqual(row[1], "") def test_handles_various_data_types(self): # Mock database connection and cursor @@ -217,24 +208,12 @@ def test_format_rows_from_table(self): ["val1", "val2"], ] ret = format_rows_from_table(d) - self.assertEqual(ret, "col1 | col2\n" 
+ "-" * 50 + "\nval1 | val2\n") - - def test_parsing_tables_from_query(self): - from explorer.assistant.utils import get_table_names_from_query - sql = "SELECT * FROM explorer_query" - ret = get_table_names_from_query(sql) - self.assertEqual(ret, ["explorer_query"]) - - def test_parsing_tables_from_no_tables(self): - from explorer.assistant.utils import get_table_names_from_query - sql = "select 1;" - ret = get_table_names_from_query(sql) - self.assertEqual(ret, []) + self.assertEqual(ret, "col1 | col2\nval1 | val2") def test_schema_info_from_table_names(self): - from explorer.assistant.utils import tables_from_schema_info - ret = tables_from_schema_info(default_db_connection(), ["explorer_query"]) - expected = [("explorer_query", [ + from explorer.assistant.utils import table_schema + ret = table_schema(default_db_connection(), "explorer_query") + expected = [ ("id", "AutoField"), ("title", "CharField"), ("sql", "TextField"), @@ -244,26 +223,196 @@ def test_schema_info_from_table_names(self): ("created_by_user_id", "IntegerField"), ("snapshot", "BooleanField"), ("connection", "CharField"), - ("database_connection_id", "IntegerField")])] + ("database_connection_id", "IntegerField"), + ("few_shot", "BooleanField")] self.assertEqual(ret, expected) @unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled") class TestAssistantUtils(TestCase): - def test_sample_rows_from_tables(self): - from explorer.assistant.utils import sample_rows_from_tables + def test_sample_rows_from_table(self): + from explorer.assistant.utils import sample_rows_from_table, format_rows_from_table SimpleQueryFactory(title="First Query") SimpleQueryFactory(title="Second Query") QueryLogFactory() - ret = sample_rows_from_tables(conn(), ["explorer_query", "explorer_querylog"]) + ret = sample_rows_from_table(conn(), "explorer_query") + self.assertEqual(len(ret), ROW_SAMPLE_SIZE) + ret = format_rows_from_table(ret) self.assertTrue("First Query" in ret) self.assertTrue("Second Query" 
in ret) - self.assertTrue("explorer_querylog" in ret) - def test_sample_rows_from_tables_no_tables(self): - from explorer.assistant.utils import sample_rows_from_tables + def test_sample_rows_from_tables_no_table_match(self): + from explorer.assistant.utils import sample_rows_from_table SimpleQueryFactory(title="First Query") SimpleQueryFactory(title="Second Query") - ret = sample_rows_from_tables(conn(), []) - self.assertEqual(ret, "") + ret = sample_rows_from_table(conn(), "banana") + self.assertEqual(ret, [["no such table: banana"]]) + + def test_relevant_few_shots(self): + relevant_q1 = SimpleQueryFactory(sql="select * from relevant_table", few_shot=True) + relevant_q2 = SimpleQueryFactory(sql="select * from conn.RELEVANT_TABLE limit 10", few_shot=True) + irrelevant_q2 = SimpleQueryFactory(sql="select * from conn.RELEVANT_TABLE limit 10", few_shot=False) + relevant_q3 = SimpleQueryFactory(sql="select * from conn.another_good_table limit 10", few_shot=True) + irrelevant_q1 = SimpleQueryFactory(sql="select * from irrelevant_table") + included_tables = ["relevant_table", "ANOTHER_GOOD_TABLE"] + res = get_relevant_few_shots(relevant_q1.database_connection, included_tables) + res_ids = [td.id for td in res] + self.assertIn(relevant_q1.id, res_ids) + self.assertIn(relevant_q2.id, res_ids) + self.assertIn(relevant_q3.id, res_ids) + self.assertNotIn(irrelevant_q1.id, res_ids) + self.assertNotIn(irrelevant_q2.id, res_ids) + + def test_get_relevant_annotations(self): + + relevant1 = TableDescription( + database_connection=default_db_connection(), + table_name="fruit" + ) + relevant2 = TableDescription( + database_connection=default_db_connection(), + table_name="Vegetables" + ) + irrelevant = TableDescription( + database_connection=default_db_connection(), + table_name="animals" + ) + relevant1.save() + relevant2.save() + irrelevant.save() + res1 = get_relevant_annotation(default_db_connection(), "Fruit") + self.assertEqual(relevant1.id, res1.id) + res2 = 
get_relevant_annotation(default_db_connection(), "vegetables") + self.assertEqual(relevant2.id, res2.id) + + +class TestAssistantHistoryApiView(TestCase): + + def setUp(self): + self.user = User.objects.create_superuser( + "admin", "admin@admin.com", "pwd" + ) + self.client.login(username="admin", password="pwd") + + def test_assistant_history_api_view(self): + # Create some PromptLogs + connection = default_db_connection() + PromptLog.objects.create( + run_by_user=self.user, + database_connection=connection, + user_request="Test request 1", + response="Test response 1", + run_at=timezone.now() + ) + PromptLog.objects.create( + run_by_user=self.user, + database_connection=connection, + user_request="Test request 2", + response="Test response 2", + run_at=timezone.now() + ) + + # Make a POST request to the API + url = reverse("assistant_history") + data = { + "connection_id": connection.id + } + response = self.client.post(url, data=json.dumps(data), content_type="application/json") + + # Check the response + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.content) + self.assertIn("logs", response_data) + self.assertEqual(len(response_data["logs"]), 2) + self.assertEqual(response_data["logs"][0]["user_request"], "Test request 2") + self.assertEqual(response_data["logs"][0]["response"], "Test response 2") + self.assertEqual(response_data["logs"][1]["user_request"], "Test request 1") + self.assertEqual(response_data["logs"][1]["response"], "Test response 1") + + def test_assistant_history_api_view_invalid_json(self): + url = reverse("assistant_history") + response = self.client.post(url, data="invalid json", content_type="application/json") + self.assertEqual(response.status_code, 400) + response_data = json.loads(response.content) + self.assertEqual(response_data["status"], "error") + self.assertEqual(response_data["message"], "Invalid JSON") + + def test_assistant_history_api_view_no_logs(self): + connection = 
default_db_connection() + url = reverse("assistant_history") + data = { + "connection_id": connection.id + } + response = self.client.post(url, data=json.dumps(data), content_type="application/json") + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.content) + self.assertIn("logs", response_data) + self.assertEqual(len(response_data["logs"]), 0) + + def test_assistant_history_api_view_filtered_results(self): + # Create two users + user1 = self.user + user2 = User.objects.create_superuser( + "admin2", "admin2@admin.com", "pwd" + ) + + # Create two database connections + connection1 = default_db_connection() + connection2 = DatabaseConnection.objects.create( + alias="test_connection", + engine="django.db.backends.sqlite3", + name=":memory:" + ) + + # Create prompt logs for both users and connections + PromptLog.objects.create( + run_by_user=user1, + database_connection=connection1, + user_request="User1 Connection1 request", + response="User1 Connection1 response", + run_at=timezone.now() + ) + PromptLog.objects.create( + run_by_user=user1, + database_connection=connection2, + user_request="User1 Connection2 request", + response="User1 Connection2 response", + run_at=timezone.now() + ) + PromptLog.objects.create( + run_by_user=user2, + database_connection=connection1, + user_request="User2 Connection1 request", + response="User2 Connection1 response", + run_at=timezone.now() + ) + + # Make a POST request to the API as user1 + url = reverse("assistant_history") + data = { + "connection_id": connection1.id + } + response = self.client.post(url, data=json.dumps(data), content_type="application/json") + + # Check the response + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.content) + self.assertIn("logs", response_data) + self.assertEqual(len(response_data["logs"]), 1) + self.assertEqual(response_data["logs"][0]["user_request"], "User1 Connection1 request") + 
self.assertEqual(response_data["logs"][0]["response"], "User1 Connection1 response") + + # Now test with user2 + self.client.logout() + self.client.login(username="admin2", password="pwd") + + response = self.client.post(url, data=json.dumps(data), content_type="application/json") + + # Check the response + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.content) + self.assertIn("logs", response_data) + self.assertEqual(len(response_data["logs"]), 1) + self.assertEqual(response_data["logs"][0]["user_request"], "User2 Connection1 request") + self.assertEqual(response_data["logs"][0]["response"], "User2 Connection1 response") diff --git a/explorer/tests/test_models.py b/explorer/tests/test_models.py index 05588f0e..915d3168 100644 --- a/explorer/tests/test_models.py +++ b/explorer/tests/test_models.py @@ -257,6 +257,80 @@ def test_local_name_calls_user_dbs_local_dir(self, mock_getcwd, mock_exists, moc # Ensure os.makedirs was called once since the directory does not exist mock_makedirs.assert_called_once_with("/mocked/path/user_dbs") + @patch("explorer.utils.get_s3_bucket") + @patch("explorer.ee.db_connections.models.cache") + def test_single_download_triggered(self, mock_cache, mock_get_s3_bucket): + # Setup mocks + mock_cache.add.return_value = True # Simulate acquiring the lock + mock_s3 = MagicMock() + mock_get_s3_bucket.return_value = mock_s3 + + # Call the method + instance = DatabaseConnection( + alias="test", + engine=DatabaseConnection.SQLITE, + name="test_db.sqlite3", + host="some-s3-bucket", + id=123 + ) + instance.download_sqlite_if_needed() + + # Assertions + mock_s3.download_file.assert_called_once() + mock_cache.add.assert_called_once() + mock_cache.delete.assert_called_once() + + @patch("explorer.utils.get_s3_bucket") + @patch("explorer.ee.db_connections.models.cache") + def test_skip_download_when_locked(self, mock_cache, mock_get_s3_bucket): + # Setup mocks + mock_cache.add.return_value = False # Simulate that 
another process has the lock + mock_s3 = MagicMock() + mock_get_s3_bucket.return_value = mock_s3 + + # Call the method + instance = DatabaseConnection( + alias="test", + engine=DatabaseConnection.SQLITE, + name="test_db.sqlite3", + host="some-s3-bucket", + id=123 + ) + instance.download_sqlite_if_needed() + + # Assertions + mock_s3.download_file.assert_not_called() + mock_cache.add.assert_called_once() + mock_cache.delete.assert_not_called() + + @patch("explorer.utils.get_s3_bucket") + def test_not_downloaded_if_file_exists_and_model_is_unsaved(self, mock_get_s3_bucket): + + # Note this is NOT being saved to disk, e.g. how a DatabaseValidate would work. + connection = DatabaseConnection( + alias="test", + engine=DatabaseConnection.SQLITE, + name="test_db.sqlite3", + host="some-s3-bucket", + ) + + def mock_download_file(path, filename): pass + mock_s3 = mock_get_s3_bucket.return_value + mock_s3.download_file = MagicMock(side_effect=mock_download_file) + + # write the file + with open(connection.local_name, "w") as f: + f.write("Initial content") + + # See if it downloads + connection.download_sqlite_if_needed() + + # And it shouldn't.... 
+ mock_s3.download_file.assert_not_called() + + # ...even though the fingerprints don't match + self.assertIsNone(connection.upload_fingerprint) + @patch("explorer.utils.get_s3_bucket") def test_fingerprint_is_updated_after_download_and_download_is_not_called_again(self, mock_get_s3_bucket): # Setup diff --git a/explorer/tests/test_views.py b/explorer/tests/test_views.py index 883e1528..20299ca7 100644 --- a/explorer/tests/test_views.py +++ b/explorer/tests/test_views.py @@ -23,6 +23,7 @@ from explorer.utils import user_can_see_query from explorer.ee.db_connections.utils import default_db_connection from explorer.schema import connection_schema_cache_key, connection_schema_json_cache_key +from explorer.assistant.models import TableDescription def reload_app_settings(): @@ -1071,12 +1072,6 @@ def test_validate_connection_invalid_form(self): self.assertEqual(response.status_code, 200) self.assertJSONEqual(response.content, {"success": False, "error": "Invalid form data"}) - def test_validate_connection_success_alt_url(self): - url = reverse("explorer_connection_validate_with_pk", args=[1]) - response = self.client.post(url, data=self.valid_data) - self.assertEqual(response.status_code, 200) - self.assertJSONEqual(response.content, {"success": True}) - def test_update_existing_connection(self): DatabaseConnection.objects.create(alias="test_alias", engine="django.db.backends.sqlite3", name=":memory:") response = self.client.post(self.url, data=self.valid_data) @@ -1179,3 +1174,15 @@ def test_database_connection_delete_view(self): def test_database_connection_upload_view(self): response = self.client.get(reverse("explorer_upload_create")) self.assertEqual(response.status_code, 200) + + def test_table_description_list_view(self): + td = TableDescription(database_connection=default_db_connection(), table_name="foo", description="annotated") + td.save() + response = self.client.get(reverse("table_description_list")) + self.assertEqual(response.status_code, 200) + + 
response = self.client.get(reverse("table_description_update", args=[td.pk])) + self.assertEqual(response.status_code, 200) + + response = self.client.get(reverse("table_description_create")) + self.assertEqual(response.status_code, 200) diff --git a/explorer/urls.py b/explorer/urls.py index 36fec7bb..c612ba5c 100644 --- a/explorer/urls.py +++ b/explorer/urls.py @@ -6,7 +6,7 @@ ListQueryView, PlayQueryView, QueryFavoritesView, QueryFavoriteView, QueryView, SchemaJsonView, SchemaView, StreamQueryView, format_sql ) -from explorer.assistant.views import AssistantHelpView +from explorer.assistant.urls import assistant_urls urlpatterns = [ path( @@ -43,7 +43,7 @@ path("favorites/", QueryFavoritesView.as_view(), name="query_favorites"), path("favorite/", QueryFavoriteView.as_view(), name="query_favorite"), path("", ListQueryView.as_view(), name="explorer_index"), - path("assistant/", AssistantHelpView.as_view(), name="assistant"), ] +urlpatterns += assistant_urls urlpatterns += ee_urls diff --git a/package-lock.json b/package-lock.json index 1cbcab11..03b0fd4c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,6 +11,7 @@ "@codemirror/language-data": "^6.3.1", "bootstrap": "^5.0.1", "bootstrap-icons": "^1.11.2", + "choices.js": "^10.2.0", "codemirror": "^6.0.1", "cookiejs": "^2.1.3", "dompurify": "^3.0.7", @@ -24,6 +25,20 @@ "vite": "^5.0.13", "vite-plugin-copy": "^0.1.6", "vite-plugin-static-copy": "^1.0.5" + }, + "optionalDependencies": { + "@rollup/rollup-linux-x64-gnu": "^4.18.1" + } + }, + "node_modules/@babel/runtime": { + "version": "7.25.0", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.0.tgz", + "integrity": "sha512-7dRy4DwXwtzBrPbZflqxnvfxLF8kdZXPkhymtDeFoFqE6ldzjQFgYTtYIFARcLEYDrqfBfYcZt1WqFxRoyC9Rw==", + "dependencies": { + "regenerator-runtime": "^0.14.0" + }, + "engines": { + "node": ">=6.9.0" } }, "node_modules/@codemirror/autocomplete": { @@ -981,13 +996,12 @@ ] }, "node_modules/@rollup/rollup-linux-x64-gnu": { - 
"version": "4.9.1", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.1.tgz", - "integrity": "sha512-kr8rEPQ6ns/Lmr/hiw8sEVj9aa07gh1/tQF2Y5HrNCCEPiCBGnBUt9tVusrcBBiJfIt1yNaXN6r1CCmpbFEDpg==", + "version": "4.21.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.21.0.tgz", + "integrity": "sha512-e2hrvElFIh6kW/UNBQK/kzqMNY5mO+67YtEh9OA65RM5IJXYTWiXjX6fjIiPaqOkBthYF1EqgiZ6OXKcQsM0hg==", "cpu": [ "x64" ], - "dev": true, "optional": true, "os": [ "linux" @@ -1112,6 +1126,16 @@ "node": ">=8" } }, + "node_modules/choices.js": { + "version": "10.2.0", + "resolved": "https://registry.npmjs.org/choices.js/-/choices.js-10.2.0.tgz", + "integrity": "sha512-8PKy6wq7BMjNwDTZwr3+Zry6G2+opJaAJDDA/j3yxvqSCnvkKe7ZIFfIyOhoc7htIWFhsfzF9tJpGUATcpUtPg==", + "dependencies": { + "deepmerge": "^4.2.2", + "fuse.js": "^6.6.2", + "redux": "^4.2.0" + } + }, "node_modules/chokidar": { "version": "3.5.3", "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", @@ -1166,6 +1190,14 @@ "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz", "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==" }, + "node_modules/deepmerge": { + "version": "4.3.1", + "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-4.3.1.tgz", + "integrity": "sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/dompurify": { "version": "3.0.7", "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.0.7.tgz", @@ -1274,6 +1306,14 @@ "node": "^8.16.0 || ^10.6.0 || >=11.0.0" } }, + "node_modules/fuse.js": { + "version": "6.6.2", + "resolved": "https://registry.npmjs.org/fuse.js/-/fuse.js-6.6.2.tgz", + "integrity": "sha512-cJaJkxCCxC8qIIcPBF9yGxY0W/tVZS3uEISDxhYIdtk8OL93pe+6Zj7LjCqVV4dzbqcriOZ+kQ/NE4RXZHsIGA==", + "engines": { + "node": 
">=10" + } + }, "node_modules/glob-parent": { "version": "5.1.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", @@ -1506,6 +1546,19 @@ "node": ">=8.10.0" } }, + "node_modules/redux": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/redux/-/redux-4.2.1.tgz", + "integrity": "sha512-LAUYz4lc+Do8/g7aeRa8JkyDErK6ekstQaqWQrNRW//MY1TvCEpMtpTWvlQ+FPbWCx+Xixu/6SHt5N0HR+SB4w==", + "dependencies": { + "@babel/runtime": "^7.9.2" + } + }, + "node_modules/regenerator-runtime": { + "version": "0.14.1", + "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.1.tgz", + "integrity": "sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==" + }, "node_modules/reusify": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", @@ -1545,6 +1598,19 @@ "fsevents": "~2.3.2" } }, + "node_modules/rollup/node_modules/@rollup/rollup-linux-x64-gnu": { + "version": "4.9.1", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.9.1.tgz", + "integrity": "sha512-kr8rEPQ6ns/Lmr/hiw8sEVj9aa07gh1/tQF2Y5HrNCCEPiCBGnBUt9tVusrcBBiJfIt1yNaXN6r1CCmpbFEDpg==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, "node_modules/run-parallel": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", diff --git a/package.json b/package.json index c7e89cad..53833c52 100644 --- a/package.json +++ b/package.json @@ -28,6 +28,7 @@ "@codemirror/language-data": "^6.3.1", "bootstrap": "^5.0.1", "bootstrap-icons": "^1.11.2", + "choices.js": "^10.2.0", "codemirror": "^6.0.1", "cookiejs": "^2.1.3", "dompurify": "^3.0.7", diff --git a/requirements/extra/assistant.txt b/requirements/extra/assistant.txt index c5ae0910..44fbead2 100644 --- a/requirements/extra/assistant.txt +++ b/requirements/extra/assistant.txt @@ -1,3 +1 @@ openai>=1.6.1 
-sql_metadata>=2.10 -tiktoken>=0.7 diff --git a/test_project/settings.py b/test_project/settings.py index f5ec93f8..92796fc0 100644 --- a/test_project/settings.py +++ b/test_project/settings.py @@ -108,3 +108,4 @@ EXPLORER_DB_CONNECTIONS_ENABLED = True EXPLORER_USER_UPLOADS_ENABLED = True EXPLORER_CHARTS_ENABLED = True +EXPLORER_ASSISTANT_MODEL_NAME = "anthropic/claude-3.5-sonnet"