@@ -120,6 +120,14 @@ def execute(self, sql, *args, **kwargs):
120120 if len (args ) > 0 and len (kwargs ) > 0 :
121121 raise RuntimeError ("cannot pass both named and positional parameters" )
122122
123+ # Infer command from (unflattened) statement
124+ for token in statements [0 ]:
125+ if token .ttype in [sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
126+ command = token .value .upper ()
127+ break
128+ else :
129+ command = None
130+
123131 # Flatten statement
124132 tokens = list (statements [0 ].flatten ())
125133
@@ -313,45 +321,41 @@ def shutdown_session(exception=None):
313321
314322 # Return value
315323 ret = True
316- if tokens [0 ].ttype == sqlparse .tokens .Keyword .DML :
317-
318- # Uppercase token's value
319- value = tokens [0 ].value .upper ()
320-
321- # If SELECT, return result set as list of dict objects
322- if value == "SELECT" :
323-
324- # Coerce types
325- rows = [dict (row ) for row in result .fetchall ()]
326- for row in rows :
327- for column in row :
328-
329- # Coerce decimal.Decimal objects to float objects
330- # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
331- if type (row [column ]) is decimal .Decimal :
332- row [column ] = float (row [column ])
333-
334- # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
335- elif type (row [column ]) is memoryview :
336- row [column ] = bytes (row [column ])
337-
338- # Rows to be returned
339- ret = rows
340-
341- # If INSERT, return primary key value for a newly inserted row (or None if none)
342- elif value == "INSERT" :
343- if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
344- try :
345- result = connection .execute ("SELECT LASTVAL()" )
346- ret = result .first ()[0 ]
347- except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
348- ret = None
349- else :
350- ret = result .lastrowid if result .rowcount == 1 else None
351-
352- # If DELETE or UPDATE, return number of rows matched
353- elif value in ["DELETE" , "UPDATE" ]:
354- ret = result .rowcount
324+
325+ # If SELECT, return result set as list of dict objects
326+ if command == "SELECT" :
327+
328+ # Coerce types
329+ rows = [dict (row ) for row in result .fetchall ()]
330+ for row in rows :
331+ for column in row :
332+
333+ # Coerce decimal.Decimal objects to float objects
334+ # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
335+ if type (row [column ]) is decimal .Decimal :
336+ row [column ] = float (row [column ])
337+
338+ # Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
339+ elif type (row [column ]) is memoryview :
340+ row [column ] = bytes (row [column ])
341+
342+ # Rows to be returned
343+ ret = rows
344+
345+ # If INSERT, return primary key value for a newly inserted row (or None if none)
346+ elif command == "INSERT" :
347+ if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
348+ try :
349+ result = connection .execute ("SELECT LASTVAL()" )
350+ ret = result .first ()[0 ]
351+ except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
352+ ret = None
353+ else :
354+ ret = result .lastrowid if result .rowcount == 1 else None
355+
356+ # If DELETE or UPDATE, return number of rows matched
357+ elif command in ["DELETE" , "UPDATE" ]:
358+ ret = result .rowcount
355359
356360 # If constraint violated, return None
357361 except sqlalchemy .exc .IntegrityError as e :
0 commit comments