3737from prompt_toolkit .layout .processors import ConditionalProcessor , HighlightMatchingBracketProcessor
3838from prompt_toolkit .lexers import PygmentsLexer
3939from prompt_toolkit .shortcuts import CompleteStyle , PromptSession
40- from pymysql import OperationalError , err
41- from pymysql .constants .ER import HANDSHAKE_ERROR
40+ import pymysql
41+ from pymysql .constants .ER import ERROR_CODE_ACCESS_DENIED , HANDSHAKE_ERROR
4242from pymysql .cursors import Cursor
4343import sqlglot
4444import sqlparse
@@ -214,7 +214,7 @@ def close(self) -> None:
214214 def register_special_commands (self ) -> None :
215215 special .register_special_command (self .change_db , "use" , "\\ u" , "Change to a new database." , aliases = ["\\ u" ])
216216 special .register_special_command (
217- self .change_db ,
217+ self .manual_reconnect ,
218218 "connect" ,
219219 "\\ r" ,
220220 "Reconnect to the database. Optional database argument." ,
@@ -261,6 +261,18 @@ def register_special_commands(self) -> None:
261261 self .change_prompt_format , "prompt" , "\\ R" , "Change prompt format." , aliases = ["\\ R" ], case_sensitive = True
262262 )
263263
264+ def manual_reconnect (self , arg : str = "" , ** _ ) -> Generator [tuple , None , None ]:
265+ """
266+ Interactive method to use for the \r command, so that the utility method
267+ may be cleanly used elsewhere.
268+ """
269+ if not self .reconnect (database = arg ):
270+ yield (None , None , None , "Not connected" )
271+ elif not arg or arg == '``' :
272+ yield (None , None , None , None )
273+ else :
274+ yield self .change_db (arg ).send (None )
275+
264276 def enable_show_warnings (self , ** _ ) -> Generator [tuple , None , None ]:
265277 self .show_warnings = True
266278 msg = "Show warnings enabled."
@@ -301,13 +313,18 @@ def change_db(self, arg: str, **_) -> Generator[tuple, None, None]:
301313 return
302314
303315 assert isinstance (self .sqlexecute , SQLExecute )
304- self .sqlexecute .change_db (arg )
316+
317+ if self .sqlexecute .dbname == arg :
318+ msg = f'You are already connected to database "{ self .sqlexecute .dbname } " as user "{ self .sqlexecute .user } "'
319+ else :
320+ self .sqlexecute .change_db (arg )
321+ msg = f'You are now connected to database "{ self .sqlexecute .dbname } " as user "{ self .sqlexecute .user } "'
305322
306323 yield (
307324 None ,
308325 None ,
309326 None ,
310- f'You are now connected to database " { self . sqlexecute . dbname } " as user " { self . sqlexecute . user } "' ,
327+ msg ,
311328 )
312329
313330 def execute_from_file (self , arg : str , ** _ ) -> Iterable [tuple ]:
@@ -526,7 +543,7 @@ def _connect() -> None:
526543 }
527544 try :
528545 self .sqlexecute = SQLExecute (** conn_config )
529- except OperationalError as e :
546+ except pymysql . OperationalError as e :
530547 if e .args [0 ] == ERROR_CODE_ACCESS_DENIED :
531548 if password_from_file is not None :
532549 conn_config ["password" ] = password_from_file
@@ -547,7 +564,7 @@ def _connect() -> None:
547564 self .echo (f"Connecting to socket { socket } , owned by user { socket_owner } " , err = True )
548565 try :
549566 _connect ()
550- except OperationalError as e :
567+ except pymysql . OperationalError as e :
551568 # These are "Can't open socket" and 2x "Can't connect"
552569 if [code for code in (2001 , 2002 , 2003 ) if code == e .args [0 ]]:
553570 self .logger .debug ("Database connection failed: %r." , e )
@@ -900,19 +917,12 @@ def one_iteration(text: str | None = None) -> None:
900917 output_res (res , start )
901918 special .unset_once_if_written (self .post_redirect_command )
902919 special .flush_pipe_once_if_written (self .post_redirect_command )
903- except err .InterfaceError :
904- logger .debug ("Attempting to reconnect." )
905- self .echo ("Reconnecting..." , fg = "yellow" )
906- try :
907- sqlexecute .connect ()
908- logger .debug ("Reconnected successfully." )
909- one_iteration (text )
910- return # OK to just return, cuz the recursion call runs to the end.
911- except OperationalError as e2 :
912- logger .debug ("Reconnect failed. e: %r" , e2 )
913- self .echo (str (e2 ), err = True , fg = "red" )
914- # If reconnection failed, don't proceed further.
920+ except pymysql .err .InterfaceError :
921+ # attempt to reconnect
922+ if not self .reconnect ():
915923 return
924+ one_iteration (text )
925+ return # OK to just return, cuz the recursion call runs to the end.
916926 except EOFError as e :
917927 raise e
918928 except KeyboardInterrupt :
@@ -943,21 +953,14 @@ def one_iteration(text: str | None = None) -> None:
943953 self .echo ("Did not get a connection id, skip cancelling query" , err = True , fg = "red" )
944954 except NotImplementedError :
945955 self .echo ("Not Yet Implemented." , fg = "yellow" )
946- except OperationalError as e1 :
956+ except pymysql . OperationalError as e1 :
947957 logger .debug ("Exception: %r" , e1 )
948958 if e1 .args [0 ] in (2003 , 2006 , 2013 ):
949- logger .debug ("Attempting to reconnect." )
950- self .echo ("Reconnecting..." , fg = "yellow" )
951- try :
952- sqlexecute .connect ()
953- logger .debug ("Reconnected successfully." )
954- one_iteration (text )
955- return # OK to just return, cuz the recursion call runs to the end.
956- except OperationalError as e2 :
957- logger .debug ("Reconnect failed. e: %r" , e2 )
958- self .echo (str (e2 ), err = True , fg = "red" )
959- # If reconnection failed, don't proceed further.
959+ # attempt to reconnect
960+ if not self .reconnect ():
960961 return
962+ one_iteration (text )
963+ return # OK to just return, cuz the recursion call runs to the end.
961964 else :
962965 logger .error ("sql: %r, error: %r" , text , e1 )
963966 logger .error ("traceback: %r" , traceback .format_exc ())
@@ -1029,6 +1032,58 @@ def one_iteration(text: str | None = None) -> None:
10291032 if not self .less_chatty :
10301033 self .echo ("Goodbye!" )
10311034
1035+ def reconnect (self , database : str = "" ) -> bool :
1036+ """
1037+ Attempt to reconnect to the server. Return True if successful,
1038+ False if unsuccessful.
1039+
1040+ The "database" argument is used only to improve messages.
1041+ """
1042+ assert self .sqlexecute is not None
1043+ assert self .sqlexecute .conn is not None
1044+
1045+ # First pass with ping(reconnect=False) and minimal feedback levels. This definitely
1046+ # works as expected, and is a good idea especially when "connect" was used as a
1047+ # synonym for "use".
1048+ try :
1049+ self .sqlexecute .conn .ping (reconnect = False )
1050+ if not database :
1051+ self .echo ("Already connected." , fg = "yellow" )
1052+ return True
1053+ except pymysql .err .Error :
1054+ pass
1055+
1056+ # Second pass with ping(reconnect=True). It is not demonstrated that this pass ever
1057+ # gives the benefit it is looking for, _ie_ preserves session state. We need to test
1058+ # this with connection pooling.
1059+ try :
1060+ old_connection_id = self .sqlexecute .connection_id
1061+ self .logger .debug ("Attempting to reconnect." )
1062+ self .echo ("Reconnecting..." , fg = "yellow" )
1063+ self .sqlexecute .conn .ping (reconnect = True )
1064+ self .logger .debug ("Reconnected successfully." )
1065+ self .echo ("Reconnected successfully." , fg = "yellow" )
1066+ self .sqlexecute .reset_connection_id ()
1067+ if old_connection_id != self .sqlexecute .connection_id :
1068+ self .echo ("Any session state was reset." , fg = "red" )
1069+ return True
1070+ except pymysql .err .Error :
1071+ pass
1072+
1073+ # Third pass with sqlexecute.connect() should always work, but always resets session state.
1074+ try :
1075+ self .logger .debug ("Creating new connection" )
1076+ self .echo ("Creating new connection..." , fg = "yellow" )
1077+ self .sqlexecute .connect ()
1078+ self .logger .debug ("New connection created successfully." )
1079+ self .echo ("New connection created successfully." , fg = "yellow" )
1080+ self .echo ("Any session state was reset." , fg = "red" )
1081+ return True
1082+ except pymysql .OperationalError as e :
1083+ self .logger .debug ("Reconnect failed. e: %r" , e )
1084+ self .echo (str (e ), err = True , fg = "red" )
1085+ return False
1086+
10321087 def log_output (self , output : str ) -> None :
10331088 """Log the output in the audit log, if it's enabled."""
10341089 if isinstance (self .logfile , TextIOWrapper ):
0 commit comments