diff --git a/pyproject.toml b/pyproject.toml index 5a1517f..7523817 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "kraw" -version = "0.3.2" +version = "0.3.3" description = "Reddit API wrapper for Knew Karma" authors = ["Richard Mwewa "] license = "GPL-3.0+" diff --git a/src/kraw/connection.py b/src/kraw/connection.py index def2d10..4a6da63 100644 --- a/src/kraw/connection.py +++ b/src/kraw/connection.py @@ -4,7 +4,7 @@ from types import SimpleNamespace from typing import Optional, Callable, List, Dict, Union -from aiohttp import ClientSession +import aiohttp from . import dummies @@ -30,15 +30,20 @@ def __init__(self, headers: Dict): async def send_request( self, - session: ClientSession, + session: aiohttp.ClientSession, endpoint: str, params: Optional[Dict] = None, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, ) -> Union[Dict, List, bool, None]: try: async with session.get( - url=endpoint, headers=self._headers, params=params, proxy=proxy + url=endpoint, + headers=self._headers, + params=params, + proxy=proxy, + proxy_auth=proxy_auth, ) as response: response.raise_for_status() response_data: Union[Dict, List] = await response.json() @@ -49,7 +54,7 @@ async def send_request( async def paginate_response( self, - session: ClientSession, + session: aiohttp.ClientSession, endpoint: str, limit: int, parser: Callable, @@ -57,6 +62,7 @@ async def paginate_response( status: Optional[dummies.Status] = None, params: Optional[Dict] = None, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, is_post_comments: Optional[bool] = False, ) -> List[SimpleNamespace]: @@ -77,6 +83,7 @@ async def paginate_response( ), params=params, proxy=proxy, + proxy_auth=proxy_auth, ) if is_post_comments: @@ -84,6 +91,7 @@ async def paginate_response( session=session, endpoint=endpoint, proxy=proxy, + proxy_auth=proxy_auth, response=parser(response[1]), parser=parser, limit=limit, @@ -137,7 +145,7 @@ async def paginate_response( async def _paginate_more_items( self, - session: ClientSession, + session: aiohttp.ClientSession, more_items_ids: List[str], endpoint: str, parser: Callable, @@ -145,6 +153,7 @@ async def _paginate_more_items( limit: int, status: Optional[dummies.Status] = None, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, message: Optional[dummies.Message] = None, ): # Track how many more items are needed to meet the overall limit @@ -163,7 +172,10 @@ async def _paginate_more_items( more_endpoint = f"{endpoint}?comment={more_id}" # Make an asynchronous request to fetch the additional comments. more_response = await self.send_request( - session=session, endpoint=more_endpoint, proxy=proxy + session=session, + endpoint=more_endpoint, + proxy=proxy, + proxy_auth=proxy_auth, ) # Extract the items (comments) from the response. more_items = parser(response=more_response[1]) @@ -209,6 +221,7 @@ async def _process_post_comments(self, **kwargs): await self._paginate_more_items( session=kwargs.get("session"), proxy=kwargs.get("proxy"), + proxy_auth=kwargs.get("proxy_auth"), message=kwargs.get("message"), status=kwargs.get("status"), fetched_items=items, diff --git a/src/kraw/reddit.py b/src/kraw/reddit.py index f1c1f3f..65f45e9 100644 --- a/src/kraw/reddit.py +++ b/src/kraw/reddit.py @@ -1,8 +1,8 @@ from types import SimpleNamespace from typing import Literal, Union, Optional, List, Dict +import aiohttp import karmakaze -from aiohttp import ClientSession from . import dummies @@ -39,8 +39,9 @@ def __init__(self, headers: Dict, time_format: TIME_FORMAT = "locale"): async def infra_status( self, - session: ClientSession, + session: aiohttp.ClientSession, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, message: Optional[dummies.Message] = None, status: Optional[dummies.Status] = None, ) -> Union[List[Dict], None]: @@ -50,8 +51,9 @@ async def infra_status( status_response: Dict = await self.connection.send_request( session=session, - proxy=proxy, endpoint=self.connection.endpoints.infra_status, + proxy=proxy, + proxy_auth=proxy_auth, ) indicator = status_response.get("status").get("indicator") @@ -83,12 +85,13 @@ async def infra_status( async def comments( self, - session: ClientSession, + session: aiohttp.ClientSession, kind: COMMENTS_KIND, limit: int, sort: SORT, timeframe: TIMEFRAME, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, message: Optional[dummies.Message] = None, status: Optional[dummies.Status] = None, **kwargs: str, @@ -111,6 +114,7 @@ async def comments( session=session, endpoint=endpoint, proxy=proxy, + proxy_auth=proxy_auth, params=params, limit=limit, parser=self._parse.comments, @@ -128,8 +132,9 @@ async def post( self, id: str, subreddit: str, - session: ClientSession, + session: aiohttp.ClientSession, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, status: Optional[dummies.Status] = None, ) -> SimpleNamespace: if status: @@ -139,6 +144,7 @@ async def post( session=session, endpoint=f"{self.connection.endpoints.subreddit}/{subreddit}/comments/{id}.json", proxy=proxy, + proxy_auth=proxy_auth, ) sanitised_response = self._parse.post(response=response) @@ -146,12 +152,13 @@ async def post( async def posts( self, - session: ClientSession, + session: aiohttp.ClientSession, kind: POSTS_KIND, limit: int, sort: SORT, timeframe: TIMEFRAME, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, message: Optional[dummies.Message] = None, status: Optional[dummies.Status] = None, **kwargs: str, @@ -191,6 +198,7 @@ async def posts( session=session, endpoint=endpoint, proxy=proxy, + proxy_auth=proxy_auth, params=params, limit=limit, parser=self._parse.posts, @@ -205,12 +213,13 @@ async def posts( async def search( self, - session: ClientSession, + session: aiohttp.ClientSession, kind: SEARCH_KIND, query: str, limit: int, sort: SORT, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, message: Optional[dummies.Message] = None, status: Optional[dummies.Status] = None, ) -> List[SimpleNamespace]: @@ -239,6 +248,7 @@ async def search( session=session, endpoint=endpoint, proxy=proxy, + proxy_auth=proxy_auth, params=params, parser=parser, limit=limit, @@ -254,8 +264,9 @@ async def search( async def subreddit( self, name: str, - session: ClientSession, + session: aiohttp.ClientSession, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, status: Optional[dummies.Status] = None, ) -> SimpleNamespace: if status: @@ -265,6 +276,7 @@ async def subreddit( session=session, endpoint=f"{self.connection.endpoints.subreddit}/{name}/about.json", proxy=proxy, + proxy_auth=proxy_auth, ) sanitised_response = self._parse.subreddit(response=response) @@ -272,12 +284,14 @@ async def subreddit( async def subreddits( self, - session: ClientSession, + session: aiohttp.ClientSession, kind: SUBREDDITS_KIND, limit: int, timeframe: TIMEFRAME, message: Optional[dummies.Message] = None, status: Optional[dummies.Status] = None, + proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, **kwargs: str, ) -> Union[List[Dict], Dict]: @@ -305,7 +319,8 @@ async def subreddits( subreddits = await self.connection.paginate_response( session=session, endpoint=endpoint, - proxy=kwargs.get("proxy"), + proxy=proxy, + proxy_auth=proxy_auth, params=params, parser=self._parse.subreddits, limit=limit, @@ -321,8 +336,9 @@ async def subreddits( async def user( self, name: str, - session: ClientSession, + session: aiohttp.ClientSession, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, status: Optional[dummies.Status] = None, ) -> SimpleNamespace: if status: @@ -332,6 +348,7 @@ async def user( session=session, endpoint=f"{self.connection.endpoints.user}/{name}/about.json", proxy=proxy, + proxy_auth=proxy_auth, ) sanitised_response = self._parse.user(response=response) @@ -339,13 +356,14 @@ async def user( async def users( self, - session: ClientSession, + session: aiohttp.ClientSession, kind: USERS_KIND, limit: int, timeframe: TIMEFRAME, - proxy: Optional[str] = None, message: Optional[dummies.Message] = None, status: Optional[dummies.Status] = None, + proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, ) -> List[SimpleNamespace]: users_map = { @@ -367,6 +385,7 @@ async def users( session=session, endpoint=endpoint, proxy=proxy, + proxy_auth=proxy_auth, params=params, parser=self._parse.users, limit=limit, @@ -383,8 +402,9 @@ async def wiki_page( self, name: str, subreddit: str, - session: ClientSession, + session: aiohttp.ClientSession, proxy: Optional[str] = None, + proxy_auth: Optional[aiohttp.BasicAuth] = None, status: Optional[dummies.Status] = None, ) -> SimpleNamespace: if status: @@ -394,6 +414,7 @@ async def wiki_page( session=session, endpoint=f"{self.connection.endpoints.subreddit}/{subreddit}/wiki/{name}.json", proxy=proxy, + proxy_auth=proxy_auth, ) sanitised_response = self._parse.wiki_page(response=response)