@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
4343 import os
4444 import re
4545 import sqlalchemy
46+ import sqlalchemy .orm
4647 import sqlite3
4748
4849 # Get logger
@@ -59,6 +60,11 @@ def __init__(self, url, **kwargs):
5960 # Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
6061 self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = False )
6162
63+ # Create a variable to hold the session. If None, autocommit is on.
64+ self ._Session = sqlalchemy .orm .session .sessionmaker (bind = self ._engine )
65+ self ._session = None
66+ self ._in_transaction = False
67+
6268 # Listener for connections
6369 def connect (dbapi_connection , connection_record ):
6470
@@ -90,9 +96,8 @@ def connect(dbapi_connection, connection_record):
9096 self ._logger .disabled = disabled
9197
9298 def __del__ (self ):
93- """Close database connection."""
94- if hasattr (self , "_connection" ):
95- self ._connection .close ()
99+ """Close database session and connection."""
100+ self ._close_session ()
96101
97102 @_enable_logging
98103 def execute (self , sql , * args , ** kwargs ):
@@ -125,6 +130,12 @@ def execute(self, sql, *args, **kwargs):
125130 if token .ttype in [sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
126131 command = token .value .upper ()
127132 break
133+
134+ # Begin a new session, if transaction opened by caller (not using autocommit)
135+ elif token .value .upper () in ["BEGIN" , "START" ]:
136+ if self ._in_transaction :
137+ raise RuntimeError ("transaction already open" )
138+ self ._in_transaction = True
128139 else :
129140 command = None
130141
@@ -272,6 +283,10 @@ def execute(self, sql, *args, **kwargs):
272283 statement = "" .join ([str (token ) for token in tokens ])
273284
274285 # Connect to database (for transactions' sake)
286+ if self ._session is None :
287+ self ._session = self ._Session ()
288+
289+ # Set up a Flask app teardown function to close session at teardown
275290 try :
276291
277292 # Infer whether Flask is installed
@@ -280,29 +295,17 @@ def execute(self, sql, *args, **kwargs):
280295 # Infer whether app is defined
281296 assert flask .current_app
282297
283- # If no connection for app's current request yet
284- if not hasattr (flask .g , "_connection" ):
298+ # Disconnect later - but only once
299+ if not hasattr (self , "_teardown_appcontext_added" ):
300+ self ._teardown_appcontext_added = True
285301
286- # Connect now
287- flask .g ._connection = self ._engine .connect ()
288-
289- # Disconnect later
290302 @flask .current_app .teardown_appcontext
291303 def shutdown_session (exception = None ):
292- if hasattr (flask .g , "_connection" ):
293- flask .g ._connection .close ()
294-
295- # Use this connection
296- connection = flask .g ._connection
304+ """Close any existing session on app context teardown."""
305+ self ._close_session ()
297306
298307 except (ModuleNotFoundError , AssertionError ):
299-
300- # If no connection yet
301- if not hasattr (self , "_connection" ):
302- self ._connection = self ._engine .connect ()
303-
304- # Use this connection
305- connection = self ._connection
308+ pass
306309
307310 # Catch SQLAlchemy warnings
308311 with warnings .catch_warnings ():
@@ -316,8 +319,14 @@ def shutdown_session(exception=None):
316319 # Join tokens into statement, abbreviating binary data as <class 'bytes'>
317320 _statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
318321
322+ # If COMMIT or ROLLBACK, turn on autocommit mode
323+ if command in ["COMMIT" , "ROLLBACK" ] and "TO" not in (token .value for token in tokens ):
324+ if not self ._in_transaction :
325+ raise RuntimeError ("transactions must be opened with BEGIN or START TRANSACTION" )
326+ self ._in_transaction = False
327+
319328 # Execute statement
320- result = connection .execute (sqlalchemy .text (statement ))
329+ result = self . _session .execute (sqlalchemy .text (statement ))
321330
322331 # Return value
323332 ret = True
@@ -346,7 +355,7 @@ def shutdown_session(exception=None):
346355 elif command == "INSERT" :
347356 if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
348357 try :
349- result = connection .execute ("SELECT LASTVAL()" )
358+ result = self . _session .execute ("SELECT LASTVAL()" )
350359 ret = result .first ()[0 ]
351360 except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
352361 ret = None
@@ -357,6 +366,10 @@ def shutdown_session(exception=None):
357366 elif command in ["DELETE" , "UPDATE" ]:
358367 ret = result .rowcount
359368
369+ # If autocommit is on, commit
370+ if not self ._in_transaction :
371+ self ._session .commit ()
372+
360373 # If constraint violated, return None
361374 except sqlalchemy .exc .IntegrityError as e :
362375 self ._logger .debug (termcolor .colored (statement , "yellow" ))
@@ -376,6 +389,13 @@ def shutdown_session(exception=None):
376389 self ._logger .debug (termcolor .colored (_statement , "green" ))
377390 return ret
378391
392+ def _close_session (self ):
393+ """Closes any existing session and resets instance variables."""
394+ if self ._session is not None :
395+ self ._session .close ()
396+ self ._session = None
397+ self ._in_transaction = False
398+
379399 def _escape (self , value ):
380400 """
381401 Escapes value using engine's conversion function.
0 commit comments