55from .._exceptions import ConnectError , ConnectTimeout
66from .._types import URL , Headers , Origin , TimeoutDict
77from .._utils import exponential_backoff , get_logger , url_to_origin
8- from .base import (
9- AsyncByteStream ,
10- AsyncHTTPTransport ,
11- ConnectionState ,
12- NewConnectionRequired ,
13- )
8+ from .base import AsyncByteStream , AsyncHTTPTransport , NewConnectionRequired
149from .http import AsyncBaseHTTPConnection
1510from .http11 import AsyncHTTP11Connection
1611
@@ -25,6 +20,7 @@ def __init__(
2520 origin : Origin ,
2621 http1 : bool = True ,
2722 http2 : bool = False ,
23+ keepalive_expiry : float = None ,
2824 uds : str = None ,
2925 ssl_context : SSLContext = None ,
3026 socket : AsyncSocketStream = None ,
@@ -35,6 +31,7 @@ def __init__(
3531 self .origin = origin
3632 self .http1 = http1
3733 self .http2 = http2
34+ self .keepalive_expiry = keepalive_expiry
3835 self .uds = uds
3936 self .ssl_context = SSLContext () if ssl_context is None else ssl_context
4037 self .socket = socket
@@ -57,20 +54,58 @@ def __init__(
5754 self .backend = AutoBackend () if backend is None else backend
5855
5956 def __repr__ (self ) -> str :
60- http_version = "UNKNOWN"
61- if self .is_http11 :
62- http_version = "HTTP/1.1"
63- elif self .is_http2 :
64- http_version = "HTTP/2"
65- return f"<AsyncHTTPConnection http_version={ http_version } state={ self .state } >"
57+ return f"<AsyncHTTPConnection [{ self .info ()} ]>"
6658
6759 def info (self ) -> str :
6860 if self .connection is None :
69- return "Not connected"
70- elif self .state == ConnectionState .PENDING :
71- return "Connecting"
61+ return "Connection failed" if self .connect_failed else "Connecting"
7262 return self .connection .info ()
7363
64+ def should_close (self ) -> bool :
65+ """
66+ Return `True` if the connection is in a state where it should be closed.
67+ This occurs when any of the following occur:
68+
69+ * There are no active requests on an HTTP/1.1 connection, and the underlying
70+ socket is readable. The only valid state the socket can be readable in
71+ if this occurs is when the b"" EOF marker is about to be returned,
72+ indicating a server disconnect.
73+ * There are no active requests being made and the keepalive timeout has passed.
74+ """
75+ if self .connection is None :
76+ return False
77+ return self .connection .should_close ()
78+
79+ def is_idle (self ) -> bool :
80+ """
81+ Return `True` if the connection is currently idle.
82+ """
83+ if self .connection is None :
84+ return False
85+ return self .connection .is_idle ()
86+
87+ def is_closed (self ) -> bool :
88+ if self .connection is None :
89+ return self .connect_failed
90+ return self .connection .is_closed ()
91+
92+ def is_available (self ) -> bool :
93+ """
94+ Return `True` if the connection is currently able to accept an outgoing request.
95+ This occurs when any of the following occur:
96+
97+ * The connection has not yet been opened, and HTTP/2 support is enabled.
98+ We don't *know* at this point if we'll end up on an HTTP/2 connection or
99+ not, but we *might* do, so we indicate availability.
100+ * The connection has been opened, and is currently idle.
101+ * The connection is open, and is an HTTP/2 connection. The connection must
102+ also not currently be exceeding the maximum number of allowable concurrent
103+ streams and must not have exhausted the maximum total number of stream IDs.
104+ """
105+ if self .connection is None :
106+ return self .http2 and not self .is_closed
107+ return self .connection .is_available ()
108+
74109 @property
75110 def request_lock (self ) -> AsyncLock :
76111 # We do this lazily, to make sure backend autodetection always
@@ -91,18 +126,16 @@ async def handle_async_request(
91126 timeout = cast (TimeoutDict , extensions .get ("timeout" , {}))
92127
93128 async with self .request_lock :
94- if self .state == ConnectionState .PENDING :
129+ if self .connection is None :
130+ if self .connect_failed :
131+ raise NewConnectionRequired ()
95132 if not self .socket :
96133 logger .trace (
97134 "open_socket origin=%r timeout=%r" , self .origin , timeout
98135 )
99136 self .socket = await self ._open_socket (timeout )
100137 self ._create_connection (self .socket )
101- elif self .state in (ConnectionState .READY , ConnectionState .IDLE ):
102- pass
103- elif self .state == ConnectionState .ACTIVE and self .is_http2 :
104- pass
105- else :
138+ elif not self .connection .is_available ():
106139 raise NewConnectionRequired ()
107140
108141 assert self .connection is not None
@@ -159,33 +192,24 @@ def _create_connection(self, socket: AsyncSocketStream) -> None:
159192
160193 self .is_http2 = True
161194 self .connection = AsyncHTTP2Connection (
162- socket = socket , backend = self .backend , ssl_context = self .ssl_context
195+ socket = socket ,
196+ keepalive_expiry = self .keepalive_expiry ,
197+ backend = self .backend ,
163198 )
164199 else :
165200 self .is_http11 = True
166201 self .connection = AsyncHTTP11Connection (
167- socket = socket , ssl_context = self .ssl_context
202+ socket = socket , keepalive_expiry = self .keepalive_expiry
168203 )
169204
170- @property
171- def state (self ) -> ConnectionState :
172- if self .connect_failed :
173- return ConnectionState .CLOSED
174- elif self .connection is None :
175- return ConnectionState .PENDING
176- return self .connection .get_state ()
177-
178- def is_socket_readable (self ) -> bool :
179- return self .connection is not None and self .connection .is_socket_readable ()
180-
181- def mark_as_ready (self ) -> None :
182- if self .connection is not None :
183- self .connection .mark_as_ready ()
184-
185- async def start_tls (self , hostname : bytes , timeout : TimeoutDict = None ) -> None :
205+ async def start_tls (
206+ self , hostname : bytes , ssl_context : SSLContext , timeout : TimeoutDict = None
207+ ) -> None :
186208 if self .connection is not None :
187209 logger .trace ("start_tls hostname=%r timeout=%r" , hostname , timeout )
188- self .socket = await self .connection .start_tls (hostname , timeout )
210+ self .socket = await self .connection .start_tls (
211+ hostname , ssl_context , timeout
212+ )
189213 logger .trace ("start_tls complete hostname=%r timeout=%r" , hostname , timeout )
190214
191215 async def aclose (self ) -> None :
0 commit comments