Skip to content

Commit

Permalink
default connection first
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisclark committed Aug 15, 2024
1 parent a2cce97 commit 25543c3
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 15 deletions.
2 changes: 1 addition & 1 deletion explorer/assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def run_assistant(request_data, user):
extra_tables = request_data.get("selected_tables", [])
included_tables = get_table_names_from_query(sql) + extra_tables

connection_id = request_data.get("connection")
connection_id = request_data.get("connection_id")
try:
conn = DatabaseConnection.objects.get(id=connection_id)
except DatabaseConnection.DoesNotExist:
Expand Down
12 changes: 11 additions & 1 deletion explorer/forms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from django.forms import BooleanField, CharField, ModelForm, ValidationError
from django.forms.widgets import CheckboxInput, Select
from django.db.models import Value, IntegerField, When, Case

from explorer.app_settings import EXPLORER_DEFAULT_CONNECTION
from explorer.models import MSG_FAILED_BLACKLIST, Query
from explorer.ee.db_connections.models import DatabaseConnection
from explorer.ee.db_connections.utils import default_db_connection_id


class SqlField(CharField):
Expand Down Expand Up @@ -64,7 +66,15 @@ def created_at_time(self):

@property
def connections(self):
return [(c.id, c.alias) for c in DatabaseConnection.objects.all()]
# Ensure the default connection appears first in the dropdown in the form
result = DatabaseConnection.objects.annotate(
custom_order=Case(
When(id=default_db_connection_id(), then=Value(0)),
default=Value(1),
output_field=IntegerField(),
)
).order_by("custom_order", "id")
return [(c.id, c.alias) for c in result.all()]

class Meta:
model = Query
Expand Down
2 changes: 1 addition & 1 deletion explorer/src/js/assistant.js
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function submitAssistantAsk() {

const data = {
sql: window.editor?.state.doc.toString() ?? null,
connection: document.getElementById("id_database_connection")?.value ?? null,
connection_id: document.getElementById("id_database_connection")?.value ?? null,
assistant_request: document.getElementById("id_assistant_input")?.value ?? null,
selected_tables: selectedTables,
db_error: getErrorMessage()
Expand Down
2 changes: 1 addition & 1 deletion explorer/tests/test_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def setUp(self):
self.client.login(username="admin", password="pwd")
self.request_data = {
"sql": "SELECT * FROM explorer_query",
"connection": 1,
"connection_id": 1,
"assistant_request": "Test Request"
}

Expand Down
9 changes: 9 additions & 0 deletions explorer/tests/test_forms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from django.db.utils import IntegrityError
from django.forms.models import model_to_dict
from django.test import TestCase
from unittest.mock import patch

from explorer.forms import QueryForm
from explorer.tests.factories import SimpleQueryFactory
Expand Down Expand Up @@ -50,3 +51,11 @@ def test_valid_form_submission(self):
self.assertEqual(query.database_connection_id, default_db_connection_id())
self.assertEqual(query.title, form_data["title"])
self.assertEqual(query.sql, form_data["sql"])

@patch("explorer.forms.default_db_connection_id")
def test_default_connection_first(self, mocked_default_db_connection_id):
mocked_default_db_connection_id.return_value = default_db_connection_id()
self.assertEqual(default_db_connection_id(), QueryForm().connections[0][0])

mocked_default_db_connection_id.return_value = 2
self.assertEqual(2, QueryForm().connections[0][0])
23 changes: 14 additions & 9 deletions explorer/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from django.urls import reverse

from explorer import app_settings
from explorer.forms import QueryForm
from explorer.app_settings import EXPLORER_DEFAULT_CONNECTION as CONN
from explorer.app_settings import EXPLORER_TOKEN, EXPLORER_USER_UPLOADS_ENABLED
from explorer.models import MSG_FAILED_BLACKLIST, Query, QueryFavorite, QueryLog, DatabaseConnection
Expand Down Expand Up @@ -598,18 +599,22 @@ def test_sql_download_respects_connection(self):

url = reverse("download_sql") + "?format=csv"

response = self.client.post(
url,
{"sql": "select * from animals;", "connection": DatabaseConnection.objects.get(alias=c2_alias).id}
)
form_data = {"sql": "select * from animals;",
"title": "foo",
"database_connection": DatabaseConnection.objects.get(alias=c2_alias).id}
form = QueryForm(data=form_data)
self.assertTrue(form.is_valid())
response = self.client.post(url, form.data)

self.assertEqual(response.status_code, 200)
self.assertContains(response, "superchicken")

def test_sql_download_csv_with_custom_delim(self):
url = reverse("download_sql") + "?format=csv&delim=|"

response = self.client.post(url, {"sql": "select 1,2;"})
form_data = {"sql": "select 1,2;", "title": "foo", "database_connection": default_db_connection().id}
form = QueryForm(data=form_data)
self.assertTrue(form.is_valid())
response = self.client.post(url, form.data)

self.assertEqual(response.status_code, 200)
self.assertEqual(response["content-type"], "text/csv")
Expand Down Expand Up @@ -674,7 +679,7 @@ def test_returns_404_if_conn_doesnt_exist(self):
def test_admin_required(self):
self.client.logout()
resp = self.client.get(
reverse("explorer_schema", kwargs={"connection": default_db_connection().alias})
reverse("explorer_schema", kwargs={"connection": default_db_connection().id})
)
self.assertTemplateUsed(resp, "admin/login.html")

Expand Down Expand Up @@ -941,7 +946,7 @@ def test_upload_file(self, mock_upload_sqlite):
conn = DatabaseConnection.objects.filter(alias__contains="kings").first()
resp = self.client.post(
reverse("explorer_playground"),
{"sql": "select * from kings where Name = 'Athelstan';", "connection": conn.id}
{"sql": "select * from kings where Name = 'Athelstan';", "database_connection": conn.id}
)
self.assertIn("925-940", resp.content.decode("utf-8"))

Expand All @@ -957,7 +962,7 @@ def test_upload_file(self, mock_upload_sqlite):
# Query it and make sure a valid result is in the response. Note this is the *same* connection.
resp = self.client.post(
reverse("explorer_playground"),
{"sql": "select * from rc_sample where material_type = 'Steel';", "connection": conn.id}
{"sql": "select * from rc_sample where material_type = 'Steel';", "database_connection": conn.id}
)
self.assertIn("Goudurix", resp.content.decode("utf-8"))

Expand Down
2 changes: 1 addition & 1 deletion explorer/views/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DownloadFromSqlView(PermissionRequiredMixin, View):

def post(self, request, *args, **kwargs):
sql = request.POST.get("sql", "")
connection = request.POST.get("connection", default_db_connection_id())
connection = request.POST.get("database_connection", default_db_connection_id())
query = Query(sql=sql, database_connection_id=connection, title="")
ql = query.log(request.user)
query.title = f"Playground-{ql.id}"
Expand Down
2 changes: 1 addition & 1 deletion explorer/views/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get(self, request):
return self.render()

def post(self, request):
c = request.POST.get("connection", default_db_connection_id())
c = request.POST.get("database_connection", default_db_connection_id())
show = url_get_show(request)
sql = request.POST.get("sql", "")

Expand Down

0 comments on commit 25543c3

Please sign in to comment.