1717
1818from aiohttp import web_request
1919from aiohttp import web_response
20+ from aiohttp import hdrs
2021from .routing import Handler
2122
2223class DynamicMiddleware (typing .Sized ,
@@ -32,54 +33,86 @@ class DynamicMiddleware(typing.Sized,
3233 __middleware_version__ : int = 1
3334
3435 # List of middlewares tuples
35- # (middleware_name, middleware_handler)
36- _handlers : typing .List [typing .Tuple [str , typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]]]
36+ # (domain_suffix, middleware_name, middleware_handler)
37+ _handlers : typing .List [typing .Tuple [str , str , typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]]]
3738
3839 def __init__ (self ,
3940 middlewares : typing .List [typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]] = [],
40- named_middlewares : typing .List [typing .Tuple [str , typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]]] = []) -> None :
41+ named_middlewares : typing .List [typing .Tuple [str , str , typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]]] = []) -> None :
4142 super ().__init__ ()
4243
4344 # Append unnamed and named middlewares
44- self ._handlers = [] if middlewares is None else [ (None , m ) for m in middlewares ]
45+ self ._handlers = [] if middlewares is None else [ ('' , None , m ) for m in middlewares ]
4546 self ._handlers .extend ([] if named_middlewares is None else named_middlewares )
4647
48+ # Sort by suffix match
49+ self ._handlers .sort (key = functools .cmp_to_key (lambda x , y : 1 if y .endswith (x ) else - 1 ))
50+
4751 @property
48- def handlers (self ) -> typing .List [typing .Tuple [str , typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]]]:
52+ def handlers (self ) -> typing .List [typing .Tuple [str , str , typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]]]:
53+ """
54+ Returns list of tuples (domain_suffix, middleware_name, middleware_handler)
55+ """
4956 return self .handlers
5057
51- def add_handler (self , middleware : typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]) -> None :
58+ def add_handler (self , middleware : typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]], domain : str = '' , overwrite : bool = True ) -> None :
5259 """
53- Append middleware to the end of middlewares list.
60+ Append middleware to the end of middlewares list. Domain is optional and
61+ defines suffix for middleware.
5462 """
5563
56- self ._handlers .append ((None , middleware ))
64+ if domain is None :
65+ return False
66+
67+ self ._handlers .append ((domain , None , middleware ))
68+
69+ # TODO: Insert without sort
70+ # Sort by suffix match
71+ self ._handlers .sort (key = functools .cmp_to_key (lambda x , y : 1 if y [0 ].endswith (x [0 ]) else - 1 ))
72+
73+ return True
5774
58- def add_named_handler (self , middleware : typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]], name : str , overwrite : bool = True ) -> bool :
75+ def add_named_handler (self , middleware : typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]], name : str , domain : str = '' , overwrite : bool = True ) -> bool :
5976 """
6077 Append middleware to the end of middlewares list or replace existing by
61- the same name if `overwrite` is set to True.
78+ the same name if `overwrite` is set to True. Domain is optional and
79+ defines suffix for middleware.
80+ Named handler overwrites only by name. in case when domain match with
81+ another record, it is ignored.
6282 Returns True if added.
6383 """
6484
85+ if domain is None :
86+ return False
87+
6588 if name is not None :
6689 for i , m in enumerate (self ._handlers ):
67- if m [0 ] == name :
90+ if m [1 ] == name :
6891 if overwrite :
69- self ._handlers [i ] = (name , middleware )
92+ self ._handlers [i ] = (domain , name , middleware )
93+
94+ # TODO: Insert without sort
95+ # Sort by suffix match
96+ self ._handlers .sort (key = functools .cmp_to_key (lambda x , y : 1 if y [0 ].endswith (x [0 ]) else - 1 ))
97+
7098 return True
7199 else :
72100 return False
73101
74- self ._handlers .append ((name , middleware ))
102+ self ._handlers .append ((domain , name , middleware ))
103+
104+ # TODO: Insert without sort
105+ # Sort by suffix match
106+ self ._handlers .sort (key = functools .cmp_to_key (lambda x , y : 1 if y [0 ].endswith (x [0 ]) else - 1 ))
107+
75108 return True
76109
77110 def get_handler (self , index : int ) -> typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]:
78111 """
79112 Returns middleware handler by index.
80113 """
81114
82- return self ._handlers [index ][1 ]
115+ return self ._handlers [index ][2 ]
83116
84117 def get_named_handler (self , name : str ) -> typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]:
85118 """
@@ -88,19 +121,71 @@ def get_named_handler(self, name: str) -> typing.Callable[[web_request.Request,
88121
89122 if name is not None :
90123 for m in self ._handlers :
91- if m [0 ] == name :
92- return m [1 ]
124+ if m [1 ] == name :
125+ return m [2 ]
93126
94127 return None
95128
129+ def get_domain_handlers (self , domain : str ) -> typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]:
130+ """
131+ Returns handlers for given domain suffix.
132+ """
133+
134+ result = []
135+
136+ if domain is not None :
137+ for m in self ._handlers :
138+ if m [0 ] == domain :
139+ result .append (m [2 ])
140+
141+ return result
142+
143+ def get_matching_domain_handlers (self , domain : str ) -> typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]:
144+ """
145+ Returns handlers with domain matching given domain suffix.
146+ """
147+
148+ result = []
149+
150+ if domain is not None :
151+ for m in self ._handlers :
152+ if m [0 ].endswith (domain ):
153+ result .append (m [2 ])
154+
155+ return result
156+
96157 def contains_named_handler (self , name : str ) -> bool :
97158 """
98159 Checks if there exists middleware with specified name.
99160 """
100161
101162 if name is not None :
102163 for m in self ._handlers :
103- if m [0 ] == name :
164+ if m [1 ] == name :
165+ return True
166+
167+ return False
168+
169+ def contains_domain_handler (self , domain : str ) -> bool :
170+ """
171+ Checks if there exists middleware with given domain suffix.
172+ """
173+
174+ if domain is not None :
175+ for m in self ._handlers :
176+ if m [0 ] == domain :
177+ return True
178+
179+ return False
180+
181+ def contains_matching_domain_handler (self , domain : str ) -> bool :
182+ """
183+ Checks if there exists middleware matching given domain suffix.
184+ """
185+
186+ if domain is not None :
187+ for m in self ._handlers :
188+ if m [0 ].endswith (domain ):
104189 return True
105190
106191 return False
@@ -119,8 +204,25 @@ def del_named_handler(self, name: str) -> None:
119204
120205 if name is not None :
121206 for i , m in enumerate (self ._handlers ):
122- if m [0 ] == name :
207+ if m [1 ] == name :
123208 self ._handlers .pop (i )
209+ return
210+
211+ def del_domain_handlers (self , domain : str ) -> None :
212+ """
213+ Deletes handlers for given domain suffix.
214+ """
215+
216+ if domain is not None :
217+ self ._handlers [:] = [m for m in self ._handlers if m [0 ] != domain ]
218+
219+ def del_matching_domain_handlers (self , domain : str ) -> None :
220+ """
221+ Deletes handlers matching given domain suffix.
222+ """
223+
224+ if domain is not None :
225+ self ._handlers [:] = [m for m in self ._handlers if not m [0 ].endswith (domain )]
124226
125227 def del_handlers (self ) -> None :
126228 """
@@ -129,15 +231,25 @@ def del_handlers(self) -> None:
129231
130232 self ._handlers .clear ()
131233
132- def __iter__ (self ) -> typing .Iterator [typing .Tuple [str , typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]]]:
234+ def __iter__ (self ) -> typing .Iterator [typing .Tuple [str , str , typing .Callable [[web_request .Request , Handler ], typing .Awaitable [web_response .StreamResponse ]]]]:
133235 return iter (self ._handlers )
134236
135237 def __len__ (self ) -> int :
136238 return len (self ._handlers )
137239
138240 async def __call__ (self , request : web_request .Request , handler : Handler ) -> web_response .StreamResponse :
139- # Rewrap and return
140- for n , h in reversed (self ._handlers ): # Do not remove n,
141- handler = functools .update_wrapper (functools .partial (h , handler = handler ), handler )
142-
143- return await handler (request )
241+ domain = request .headers .get (hdrs .HOST , None )
242+
243+ if domain is None :
244+ # Rewrap and return
245+ for s , n , h in reversed (self ._handlers ): # Do not remove n,
246+ handler = functools .update_wrapper (functools .partial (h , handler = handler ), handler )
247+
248+ return await handler (request )
249+ else :
250+ # Rewrap only matching suffix and return
251+ for s , n , h in reversed (self ._handlers ): # Do not remove n,
252+ if domain .endswith (s ):
253+ handler = functools .update_wrapper (functools .partial (h , handler = handler ), handler )
254+
255+ return await handler (request )
0 commit comments