3939
4040
4141if t .TYPE_CHECKING :
42+ from ssl import SSLContext
43+
44+ import typing_extensions as te
45+
4246 from ..._deadline import Deadline
47+ from ...addressing import (
48+ Address ,
49+ ResolvedAddress ,
50+ )
4351
4452
4553log = logging .getLogger ("neo4j.io" )
@@ -63,7 +71,11 @@ def __str__(self):
6371
6472
6573class AsyncBoltSocket (AsyncBoltSocketBase ):
66- async def _parse_handshake_response_v1 (self , ctx , response ):
74+ async def _parse_handshake_response_v1 (
75+ self ,
76+ ctx : HandshakeCtx ,
77+ response : bytes ,
78+ ) -> tuple [int , int ]:
6779 agreed_version = response [- 1 ], response [- 2 ]
6880 log .debug (
6981 "[#%04X] S: <HANDSHAKE> 0x%06X%02X" ,
@@ -73,7 +85,11 @@ async def _parse_handshake_response_v1(self, ctx, response):
7385 )
7486 return agreed_version
7587
76- async def _parse_handshake_response_v2 (self , ctx , response ):
88+ async def _parse_handshake_response_v2 (
89+ self ,
90+ ctx : HandshakeCtx ,
91+ response : bytes ,
92+ ) -> tuple [int , int ]:
7793 ctx .ctx = "handshake v2 offerings count"
7894 num_offerings = await self ._read_varint (ctx )
7995 offerings = []
@@ -85,7 +101,7 @@ async def _parse_handshake_response_v2(self, ctx, response):
85101 ctx .ctx = "handshake v2 capabilities"
86102 _capabilities_offer = await self ._read_varint (ctx )
87103
88- if log .getEffectiveLevel () > = logging .DEBUG :
104+ if log .getEffectiveLevel () < = logging .DEBUG :
89105 log .debug (
90106 "[#%04X] S: <HANDSHAKE> %s [%i] %s %s" ,
91107 ctx .local_port ,
@@ -125,7 +141,7 @@ async def _parse_handshake_response_v2(self, ctx, response):
125141
126142 return chosen_version
127143
128- async def _read_varint (self , ctx ) :
144+ async def _read_varint (self , ctx : HandshakeCtx ) -> int :
129145 next_byte = (await self ._handshake_read (ctx , 1 ))[0 ]
130146 res = next_byte & 0x7F
131147 i = 0
@@ -136,15 +152,15 @@ async def _read_varint(self, ctx):
136152 return res
137153
138154 @staticmethod
139- def _encode_varint (n ) :
155+ def _encode_varint (n : int ) -> bytearray :
140156 res = bytearray ()
141157 while n >= 0x80 :
142158 res .append (n & 0x7F | 0x80 )
143159 n >>= 7
144160 res .append (n )
145161 return res
146162
147- async def _handshake_read (self , ctx , n ) :
163+ async def _handshake_read (self , ctx : HandshakeCtx , n : int ) -> bytes :
148164 original_timeout = self .gettimeout ()
149165 self .settimeout (ctx .deadline .to_timeout ())
150166 try :
@@ -193,7 +209,11 @@ async def _handshake_send(self, ctx, data):
193209 finally :
194210 self .settimeout (original_timeout )
195211
196- async def _handshake (self , resolved_address , deadline ):
212+ async def _handshake (
213+ self ,
214+ resolved_address : ResolvedAddress ,
215+ deadline : Deadline ,
216+ ) -> tuple [tuple [int , int ], bytes , bytes ]:
197217 """
198218 Perform BOLT handshake.
199219
@@ -204,16 +224,16 @@ async def _handshake(self, resolved_address, deadline):
204224 """
205225 local_port = self .getsockname ()[1 ]
206226
207- if log . getEffectiveLevel () >= logging . DEBUG :
208- handshake = self . Bolt . get_handshake ()
209- handshake = struct .unpack (">16B" , handshake )
210- handshake = [
211- handshake [i : i + 4 ] for i in range (0 , len (handshake ), 4 )
227+ handshake = self . Bolt . get_handshake ()
228+ if log . getEffectiveLevel () <= logging . DEBUG :
229+ handshake_bytes : t . Sequence = struct .unpack (">16B" , handshake )
230+ handshake_bytes = [
231+ handshake [i : i + 4 ] for i in range (0 , len (handshake_bytes ), 4 )
212232 ]
213233
214234 supported_versions = [
215235 f"0x{ vx [0 ]:02X} { vx [1 ]:02X} { vx [2 ]:02X} { vx [3 ]:02X} "
216- for vx in handshake
236+ for vx in handshake_bytes
217237 ]
218238
219239 log .debug (
@@ -227,7 +247,7 @@ async def _handshake(self, resolved_address, deadline):
227247 * supported_versions ,
228248 )
229249
230- request = self .Bolt .MAGIC_PREAMBLE + self . Bolt . get_handshake ()
250+ request = self .Bolt .MAGIC_PREAMBLE + handshake
231251
232252 ctx = HandshakeCtx (
233253 ctx = "handshake opening" ,
@@ -273,14 +293,14 @@ async def _handshake(self, resolved_address, deadline):
273293 @classmethod
274294 async def connect (
275295 cls ,
276- address ,
296+ address : Address ,
277297 * ,
278- tcp_timeout ,
279- deadline ,
280- custom_resolver ,
281- ssl_context ,
282- keep_alive ,
283- ):
298+ tcp_timeout : float | None ,
299+ deadline : Deadline ,
300+ custom_resolver : t . Callable | None ,
301+ ssl_context : SSLContext | None ,
302+ keep_alive : bool ,
303+ ) -> tuple [ te . Self , tuple [ int , int ], bytes , bytes ] :
284304 """
285305 Connect and perform a handshake.
286306
@@ -313,10 +333,10 @@ async def connect(
313333 )
314334 return s , agreed_version , handshake , response
315335 except (BoltError , DriverError , OSError ) as error :
316- try :
317- local_port = s . getsockname ()[ 1 ]
318- except (OSError , AttributeError , TypeError ):
319- local_port = 0
336+ local_port = 0
337+ if isinstance ( s , cls ):
338+ with suppress (OSError , AttributeError , TypeError ):
339+ local_port = s . getsockname ()[ 1 ]
320340 err_str = error .__class__ .__name__
321341 if str (error ):
322342 err_str += ": " + str (error )
@@ -331,10 +351,10 @@ async def connect(
331351 errors .append (error )
332352 failed_addresses .append (resolved_address )
333353 except asyncio .CancelledError :
334- try :
335- local_port = s . getsockname ()[ 1 ]
336- except (OSError , AttributeError , TypeError ):
337- local_port = 0
354+ local_port = 0
355+ if isinstance ( s , cls ):
356+ with suppress (OSError , AttributeError , TypeError ):
357+ local_port = s . getsockname ()[ 1 ]
338358 log .debug (
339359 "[#%04X] C: <CANCELED> %s" , local_port , resolved_address
340360 )
0 commit comments