@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
4343 import os
4444 import re
4545 import sqlalchemy
46+ import sqlalchemy .orm
4647 import sqlite3
4748
4849 # Require that file already exist for SQLite
@@ -72,12 +73,12 @@ def connect(dbapi_connection, connection_record):
7273 cursor .execute ("PRAGMA foreign_keys=ON" )
7374 cursor .close ()
7475
75- # Autocommit by default
76- self ._autocommit = True
77-
7876 # Register listener
7977 sqlalchemy .event .listen (self ._engine , "connect" , connect )
8078
79+ # Autocommit by default
80+ self ._autocommit = True
81+
8182 # Test database
8283 disabled = self ._logger .disabled
8384 self ._logger .disabled = True
@@ -96,9 +97,9 @@ def __del__(self):
9697
9798 def _disconnect (self ):
9899 """Close database connection."""
99- if hasattr (self , "_connection " ):
100- self ._connection . close ()
101- delattr (self , "_connection " )
100+ if hasattr (self , "_session " ):
101+ self ._session . remove ()
102+ delattr (self , "_session " )
102103
103104 @_enable_logging
104105 def execute (self , sql , * args , ** kwargs ):
@@ -275,33 +276,34 @@ def execute(self, sql, *args, **kwargs):
275276 # Infer whether app is defined
276277 assert flask .current_app
277278
278- # If no connections to any databases yet
279- if not hasattr (flask .g , "_connections " ):
280- setattr (flask .g , "_connections " , {})
281- connections = getattr (flask .g , "_connections " )
279+ # If no sessions for any databases yet
280+ if not hasattr (flask .g , "_sessions " ):
281+ setattr (flask .g , "_sessions " , {})
282+ sessions = getattr (flask .g , "_sessions " )
282283
283- # If not yet connected to this database
284+ # If no session yet for this database
284285 # https://flask.palletsprojects.com/en/1.1.x/appcontext/#storing-data
285- if self not in connections :
286+ # https://stackoverflow.com/a/34010159
287+ if self not in sessions :
286288
287289 # Connect to database
288- connections [self ] = self . _engine . connect ( )
290+ sessions [self ] = sqlalchemy . orm . scoping . scoped_session ( sqlalchemy . orm . sessionmaker ( bind = self . _engine ) )
289291
290- # Disconnect from database later
292+ # Remove session later
291293 if _teardown_appcontext not in flask .current_app .teardown_appcontext_funcs :
292294 flask .current_app .teardown_appcontext (_teardown_appcontext )
293295
294- # Use this connection
295- connection = connections [self ]
296+ # Use this session
297+ session = sessions [self ]
296298
297299 except (ModuleNotFoundError , AssertionError ):
298300
299301 # If no connection yet
300- if not hasattr (self , "_connection " ):
301- self ._connection = self . _engine . connect ( )
302+ if not hasattr (self , "_session " ):
303+ self ._session = sqlalchemy . orm . scoping . scoped_session ( sqlalchemy . orm . sessionmaker ( bind = self . _engine ) )
302304
303- # Use this connection
304- connection = self ._connection
305+ # Use this session
306+ session = self ._session
305307
306308 # Catch SQLAlchemy warnings
307309 with warnings .catch_warnings ():
@@ -321,10 +323,10 @@ def execute(self, sql, *args, **kwargs):
321323
322324 # Execute statement
323325 if self ._autocommit :
324- connection .execute (sqlalchemy .text ("BEGIN" ))
325- result = connection .execute (sqlalchemy .text (statement ))
326+ session .execute (sqlalchemy .text ("BEGIN" ))
327+ result = session .execute (sqlalchemy .text (statement ))
326328 if self ._autocommit :
327- connection .execute (sqlalchemy .text ("COMMIT" ))
329+ session .execute (sqlalchemy .text ("COMMIT" ))
328330
329331 # Check for end of transaction
330332 if command in ["COMMIT" , "ROLLBACK" ]:
@@ -357,7 +359,7 @@ def execute(self, sql, *args, **kwargs):
357359 elif command == "INSERT" :
358360 if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
359361 try :
360- result = connection .execute ("SELECT LASTVAL()" )
362+ result = session .execute ("SELECT LASTVAL()" )
361363 ret = result .first ()[0 ]
362364 except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
363365 ret = None
@@ -538,5 +540,5 @@ def _parse_placeholder(token):
538540def _teardown_appcontext (exception = None ):
539541 """Closes context's database connection, if any."""
540542 import flask
541- for connection in flask .g .pop ("_connections " , {}).values ():
542- connection . close ()
543+ for session in flask .g .pop ("_sessions " , {}).values ():
544+ session . remove ()
0 commit comments