diff --git a/pytest_blockage.py b/pytest_blockage.py index cd27240..b07db24 100644 --- a/pytest_blockage.py +++ b/pytest_blockage.py @@ -18,7 +18,12 @@ class MockSmtpCall(Exception): pass -def block_http(whitelist): +def http_fail(host): + logger.warning('Denied HTTP connection to: %s' % host) + raise MockHttpCall(host) + + +def block_http_httplib(whitelist): def whitelisted(self, host, *args, **kwargs): try: string_type = basestring @@ -26,8 +31,8 @@ def whitelisted(self, host, *args, **kwargs): # python3 string_type = str if isinstance(host, string_type) and host not in whitelist: - logger.warning('Denied HTTP connection to: %s' % host) - raise MockHttpCall(host) + http_fail(host) + logger.debug('Allowed HTTP connection to: %s' % host) return self.old(host, *args, **kwargs) @@ -39,6 +44,34 @@ def whitelisted(self, host, *args, **kwargs): httplib.HTTPConnection.__init__ = whitelisted +def block_http_aiohttp(whitelist): + try: + from aiohttp.client import ClientSession + except ImportError: + # no aiohttp installed + return + + from yarl import URL + + def patch(self, method, url, **kwargs): + yurl = URL(url) + + if yurl.host not in whitelist: + http_fail(yurl.host) + + return self.old_request(method, url, **kwargs) + + if not getattr(ClientSession, 'blockage', False): + logger.debug('Monkey patching httplib') + ClientSession.old_request = ClientSession._request + ClientSession._request = patch + + +def block_http(whitelist): + block_http_httplib(whitelist) + block_http_aiohttp(whitelist) + + def block_smtp(whitelist): def whitelisted(self, host, *args, **kwargs): if isinstance(host, basestring) and host not in whitelist: