From eb7b00fe75e08b313df6a9d618ec58c3ba47f581 Mon Sep 17 00:00:00 2001 From: Sparrow0hawk Date: Sun, 18 Feb 2024 20:30:08 +0000 Subject: [PATCH] Bring up URI check for sqlite engine Includes a test to check that a database file is not created when using sqlite URI for in memory. --- src/flask_sqlalchemy/extension.py | 18 +++++++++--------- tests/test_engine.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 43e1b9a4..dba41e1a 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -607,7 +607,15 @@ def _apply_driver_defaults(self, options: dict[str, t.Any], app: Flask) -> None: url = sa.engine.make_url(options["url"]) if url.drivername in {"sqlite", "sqlite+pysqlite"}: - if url.database is None or url.database in {"", ":memory:"}: + # the url might look like sqlite:///file:path?uri=true + is_uri = url.query.get("uri", False) + + if is_uri and url.database: + db_str: t.Optional[str] = url.database[5:] + else: + db_str = url.database + + if db_str is None or db_str in {"", ":memory:"}: options["poolclass"] = sa.pool.StaticPool if "connect_args" not in options: @@ -615,14 +623,6 @@ def _apply_driver_defaults(self, options: dict[str, t.Any], app: Flask) -> None: options["connect_args"]["check_same_thread"] = False else: - # the url might look like sqlite:///file:path?uri=true - is_uri = url.query.get("uri", False) - - if is_uri: - db_str = url.database[5:] - else: - db_str = url.database - if not os.path.isabs(db_str): os.makedirs(app.instance_path, exist_ok=True) db_str = os.path.join(app.instance_path, db_str) diff --git a/tests/test_engine.py b/tests/test_engine.py index 0e88d5e3..e59ddfab 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -110,6 +110,16 @@ def test_sqlite_driver_level_uri(app: Flask, model_class: t.Any) -> None: assert os.path.exists(db_path[5:]) +@pytest.mark.usefixtures("app_ctx") +def test_sqlite_driver_level_uri_in_memory(app: Flask, model_class: t.Any) -> None: + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///file::memory:?uri=true" + db = SQLAlchemy(app, model_class=model_class) + db.create_all() + db_path = db.engine.url.database + assert db_path is not None + assert not os.path.exists(db_path[5:]) + + @unittest.mock.patch.object(SQLAlchemy, "_make_engine", autospec=True) def test_sqlite_memory_defaults( make_engine: unittest.mock.Mock, app: Flask, model_class: t.Any