diff --git a/explorer/ee/db_connections/create_sqlite.py b/explorer/ee/db_connections/create_sqlite.py
index 4a25f723..1aa2cdb8 100644
--- a/explorer/ee/db_connections/create_sqlite.py
+++ b/explorer/ee/db_connections/create_sqlite.py
@@ -1,33 +1,37 @@
 import os
 from io import BytesIO
 
-from explorer.ee.db_connections.type_infer import get_parser, is_sqlite
-from explorer.ee.db_connections.utils import pandas_to_sqlite, download_local_sqlite
+from explorer.ee.db_connections.type_infer import get_parser
+from explorer.ee.db_connections.utils import pandas_to_sqlite, download_sqlite_if_needed, uploaded_db_local_path
 
 
-def parse_to_sqlite(file, append=None) -> (BytesIO, str):
-    f_name = file.name
+def parse_to_sqlite(file, append=None, user_id=None) -> (BytesIO, str):
+    f_bytes = file.read()
+    table_name, _ = os.path.splitext(file.name)
+
+    # If this is being uploaded as a new connection, use the file name as the upload destination.
+    # If it's being appended, then we should use the same local path as the file being appended to.
+    if append:
+        f_name = os.path.basename(append)
+    else:
+        f_name = f"{table_name}_{user_id}.db"
 
-    local_path = f"{f_name}_tmp_local.db"
+    local_path = uploaded_db_local_path(f_name)
 
+    # When appending, make sure the database exists locally so that we can write to it
     if append:
-        if is_sqlite(file):
-            raise TypeError("Can't append a SQLite file to a SQLite file. Only CSV and JSON.")
-        # Get the sqlite file we are going to append to into the local filesystem
-        download_local_sqlite(append, local_path)
+        download_sqlite_if_needed(append, local_path)
 
     df_parser = get_parser(file)
     if df_parser:
         df = df_parser(f_bytes)
-        name, _ = os.path.splitext(f_name)
 
         try:
-            f_bytes = pandas_to_sqlite(df, name, local_path, append)
+            f_bytes = pandas_to_sqlite(df, table_name, local_path, append)
         except Exception as e:  # noqa
             raise ValueError(f"Error while parsing {f_name}: {e}") from e
-        # replace the previous extension with .db, as it is now a sqlite file
-        f_name = f"{name}.db"
     else:
-        return BytesIO(f_bytes), f_name  # if it's a SQLite file already, simply cough it up as a BytesIO object
+        # if it's a SQLite file already, simply cough it up as a BytesIO object
+        return BytesIO(f_bytes), f_name
 
     return f_bytes, f_name
diff --git a/explorer/ee/db_connections/models.py b/explorer/ee/db_connections/models.py
index 2b0586dd..06e457e2 100644
--- a/explorer/ee/db_connections/models.py
+++ b/explorer/ee/db_connections/models.py
@@ -1,11 +1,10 @@
-import os
 from django.conf import settings
 from django.core.exceptions import ValidationError
 from django.db import models
 from django.db.models.signals import pre_save
 from django.dispatch import receiver
 
-from explorer.ee.db_connections.utils import user_dbs_local_dir
+from explorer.ee.db_connections.utils import uploaded_db_local_path
 
 from django_cryptography.fields import encrypt
@@ -44,7 +43,7 @@ def is_upload(self):
     @property
     def local_name(self):
         if self.is_upload:
-            return os.path.join(user_dbs_local_dir(), self.name)
+            return uploaded_db_local_path(self.name)
 
     @classmethod
     def from_django_connection(cls, connection_alias):
diff --git a/explorer/ee/db_connections/utils.py b/explorer/ee/db_connections/utils.py
index 028f58fa..46626424 100644
--- a/explorer/ee/db_connections/utils.py
+++ b/explorer/ee/db_connections/utils.py
@@ -21,12 +21,10 @@ def upload_sqlite(db_bytes, path):
 # to this new database connection. Oops!
 # TODO: In the future, queries should probably be FK'ed to the ID of the connection, rather than simply
 # storing the alias of the connection as a string.
-def create_connection_for_uploaded_sqlite(filename, user_id, s3_path):
+def create_connection_for_uploaded_sqlite(filename, s3_path):
     from explorer.models import DatabaseConnection
-    base, ext = os.path.splitext(filename)
-    filename = f"{base}_{user_id}{ext}"
     return DatabaseConnection.objects.create(
-        alias=f"{filename}",
+        alias=filename,
         engine=DatabaseConnection.SQLITE,
         name=filename,
         host=s3_path
@@ -37,14 +35,14 @@
 def get_sqlite_for_connection(explorer_connection):
     # Get the database from s3, then modify the connection to work with the downloaded file.
     # E.g. "host" should not be set, and we need to get the full path to the file
     local_name = explorer_connection.local_name
-    download_local_sqlite(explorer_connection.host, local_name)
+    download_sqlite_if_needed(explorer_connection.host, local_name)
     explorer_connection.host = None
     explorer_connection.name = local_name
     return explorer_connection
 
 
-def download_local_sqlite(s3_path, local_path):
+def download_sqlite_if_needed(s3_path, local_path):
     from explorer.utils import get_s3_bucket
     if not os.path.exists(local_path):
         s3 = get_s3_bucket()
@@ -58,6 +56,10 @@ def user_dbs_local_dir():
     return d
 
 
+def uploaded_db_local_path(name):
+    return os.path.join(user_dbs_local_dir(), name)
+
+
 def create_django_style_connection(explorer_connection):
     if explorer_connection.is_upload:
@@ -92,24 +94,12 @@ def create_django_style_connection(explorer_connection):
 
 
 def sqlite_to_bytesio(local_path):
-    try:
-        db_file = io.BytesIO()
-        with open(local_path, "rb") as f:
-            db_file.write(f.read())
-        db_file.seek(0)
-        return db_file
-    finally:
-        # Delete the local SQLite database file
-        # Finally block to ensure we don't litter files around
-        os.remove(local_path)
-
-
-def drop_table_if_exists(connection, table_name):
-    cursor = connection.cursor()
-    cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';")
-    table_exists = cursor.fetchone() is not None
-    if table_exists:
-        cursor.execute(f"DROP TABLE {table_name};")
+    # Read the file into memory. It'll be uploaded to s3, and the local copy is left on disk for querying
+    db_file = io.BytesIO()
+    with open(local_path, "rb") as f:
+        db_file.write(f.read())
+    db_file.seek(0)
+    return db_file
 
 
 def pandas_to_sqlite(df, table_name, local_path, append):
@@ -119,9 +109,6 @@
     # Also, potentially (if we are appending to an existing SQLite DB) then there is actually data in the local file.
     conn = sqlite3.connect(local_path)
 
-    if append:
-        drop_table_if_exists(conn, table_name)
-
     try:
         df.to_sql(table_name, conn, if_exists="replace", index=False)
     finally:
diff --git a/explorer/ee/db_connections/views.py b/explorer/ee/db_connections/views.py
index bab41263..b0a4f685 100644
--- a/explorer/ee/db_connections/views.py
+++ b/explorer/ee/db_connections/views.py
@@ -17,6 +17,7 @@
 from explorer.views.auth import PermissionRequiredMixin
 from explorer.views.mixins import ExplorerContextMixin
 from explorer.ee.db_connections.utils import create_django_style_connection
+from explorer.ee.db_connections.mime import is_sqlite
 
 logger = logging.getLogger(__name__)
 
@@ -30,15 +31,20 @@ def post(self, request):  # noqa
         file = request.FILES.get("file")
         if file:
+            # 'append' should be None, or the s3 path of the sqlite DB to append this table to.
+            # This is stored in DatabaseConnection.host of the previously uploaded connection
             append = request.POST.get("append")
 
             if file.size > EXPLORER_MAX_UPLOAD_SIZE:
                 friendly = EXPLORER_MAX_UPLOAD_SIZE / (1024 * 1024)
                 return JsonResponse({"error": f"File size exceeds the limit of {friendly} MB"}, status=400)
 
+            # You can't double stamp a triple stamp!
+            if append and is_sqlite(file):
+                raise TypeError("Can't append a SQLite file to a SQLite file. Only CSV and JSON.")
+
             try:
-                # The 'append' should be the s3 path of the sqlite DB to append this table to
-                f_bytes, f_name = parse_to_sqlite(file, append)
+                f_bytes, f_name = parse_to_sqlite(file, append, request.user.id)
             except ValueError as e:
                 logger.error(f"Error getting bytes for {file.name}: {e}")
                 return JsonResponse({"error": "File was not csv, json, or sqlite."}, status=400)
@@ -59,7 +65,7 @@
 
             # If we are appending to an existing sqlite source, don't create a new DB connection record
             if not append:
-                create_connection_for_uploaded_sqlite(f_name, request.user.id, s3_path)
+                create_connection_for_uploaded_sqlite(f_name, s3_path)
             return JsonResponse({"success": True})
         else:
             return JsonResponse({"error": "No file provided"}, status=400)
diff --git a/explorer/tests/test_views.py b/explorer/tests/test_views.py
index d05eead4..c306acfe 100644
--- a/explorer/tests/test_views.py
+++ b/explorer/tests/test_views.py
@@ -933,11 +933,10 @@ def test_post_csv_file(self):
     # An end-to-end test that uploads a json file, verifies a connection was created, then issues a query
     # using that connection and verifies the right data is returned.
     @patch("explorer.ee.db_connections.views.upload_sqlite")
-    @patch("explorer.ee.db_connections.create_sqlite.download_local_sqlite")
-    def test_upload_file(self, mocked_download_sqlite, mock_upload_sqlite):
+    def test_upload_file(self, mock_upload_sqlite):
         self.assertFalse(DatabaseConnection.objects.filter(alias__contains="kings").exists())
 
-        # Test data file
+        # Upload some JSON
         file_path = os.path.join(os.getcwd(), "explorer/tests/json/kings.json")
         with open(file_path, "rb") as f:
             response = self.client.post(reverse("explorer_upload"), {"file": f})
@@ -946,48 +945,31 @@ def test_upload_file(self, mocked_download_sqlite, mock_upload_sqlite):
         self.assertEqual(response.status_code, 200)
         self.assertEqual(mock_upload_sqlite.call_count, 1)
 
-        # Now write the SQLite bytes locally, to the newly-created connection's local path
-        # We are going query this new data source, and writing the bytes here preempts the system's attempt to download
-        # it from S3 since the file already exists on disk. No need to mock get_sqlite_for_connection!
+        # Query it and make sure that the reign of this particular king is indeed in the results.
         conn = DatabaseConnection.objects.filter(alias__contains="kings").first()
-        os.makedirs(os.path.dirname(conn.local_name), exist_ok=True)
-        with open(conn.local_name, "wb") as temp_file:
-            temp_file.write(mock_upload_sqlite.call_args[0][0].getvalue())
-
         resp = self.client.post(
             reverse("explorer_playground"),
             {"sql": "select * from kings where Name = 'Athelstan';", "connection": conn.alias}
         )
-
-        # Assert that the reign of this particular king is indeed in the results.
self.assertIn("925-940", resp.content.decode("utf-8")) - # Append a new table - - def mocked_download(s3_path, local_path): - with open(local_path, "wb") as temp_file_append: - temp_file_append.write(mock_upload_sqlite.call_args[0][0].getvalue()) - mocked_download_sqlite.side_effect = mocked_download - + # Append a new table to the existing connection file_path = os.path.join(os.getcwd(), "explorer/tests/csvs/rc_sample.csv") with open(file_path, "rb") as f: - # append param doesn't matter because we just write the previoud DB to disk. - # normally append would specify what DB to retrieve from s3 to get appended to. - response = self.client.post(reverse("explorer_upload"), {"file": f, "append": "remote_s3_source.db"}) + response = self.client.post(reverse("explorer_upload"), {"file": f, "append": conn.host}) + + # Make sure it got re-uploaded self.assertEqual(response.status_code, 200) - self.assertEqual(mocked_download_sqlite.call_count, 1) - # The DB that has been appended to gets re-uploaded self.assertEqual(mock_upload_sqlite.call_count, 2) - # So write it back to disk so that we can query it - with open(conn.local_name, "wb") as temp_file: - temp_file.write(mock_upload_sqlite.call_args[0][0].getvalue()) + # Query it and make sure a valid result is in the response. Note this is the *same* connection. resp = self.client.post( reverse("explorer_playground"), {"sql": "select * from rc_sample where material_type = 'Steel';", "connection": conn.alias} ) self.assertIn("Goudurix", resp.content.decode("utf-8")) + # Clean up filesystem os.remove(conn.local_name) def test_post_no_file(self):