Skip to content

Commit 3c08548

Browse files
committed
Use the same session for the bot and graphql
This also moves the creation of the connector to the async start_bot func
1 parent 31fe391 commit 3c08548

File tree

3 files changed

+35
-29
lines changed

3 files changed

+35
-29
lines changed

bot/__main__.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import asyncio
22
import inspect
33
import pkgutil
4+
import socket
45
from typing import Iterator, NoReturn
56

7+
import aiohttp
68
from discord import AllowedMentions, Intents
79

810
from . import settings
@@ -31,16 +33,24 @@ def on_error(name: str) -> NoReturn:
3133

3234
async def start_bot() -> None:
3335
"""Load in extensions and start running the bot."""
34-
bot = Friendo(
35-
command_prefix=settings.COMMAND_PREFIX, help_command=None, intents=Intents.all(),
36-
allowed_mentions=AllowedMentions(everyone=False),
36+
resolver = aiohttp.AsyncResolver()
37+
connector = aiohttp.TCPConnector(
38+
resolver=resolver,
39+
family=socket.AF_INET,
3740
)
38-
39-
for cog in _get_cogs():
40-
await bot.load_extension(cog)
41-
42-
async with bot:
43-
await bot.start(settings.TOKEN)
41+
async with aiohttp.ClientSession(connector=connector) as session:
42+
bot = Friendo(
43+
command_prefix=settings.COMMAND_PREFIX, help_command=None, intents=Intents.all(),
44+
allowed_mentions=AllowedMentions(everyone=False),
45+
session=session,
46+
connector=connector,
47+
resolver=resolver,
48+
)
49+
for cog in _get_cogs():
50+
await bot.load_extension(cog)
51+
52+
async with bot:
53+
await bot.start(settings.TOKEN)
4454

4555
if __name__ == "__main__":
4656
asyncio.run(start_bot())

bot/bot.py

+13-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import asyncio
21
import logging
3-
import socket
42

53
import aiohttp
64
from discord.ext.commands import Bot, CommandError, Context
@@ -15,23 +13,24 @@
1513
class Friendo(Bot):
1614
"""Base Class for the discord bot."""
1715

18-
def __init__(self, *args, **kwargs) -> None:
16+
def __init__(
17+
self,
18+
session: aiohttp.ClientSession,
19+
connector: aiohttp.TCPConnector,
20+
resolver: aiohttp.AsyncResolver,
21+
*args,
22+
**kwargs,
23+
) -> None:
1924
super().__init__(*args, **kwargs)
2025

21-
# Setting the loop.
22-
self.loop = asyncio.get_event_loop()
23-
24-
self._resolver = aiohttp.AsyncResolver()
25-
self._connector = aiohttp.TCPConnector(
26-
resolver=self._resolver,
27-
family=socket.AF_INET,
28-
)
2926
# Client.login() will call HTTPClient.static_login() which will create a session using
3027
# this connector attribute.
31-
self.http.connector = self._connector
28+
self.http.connector = connector
3229

33-
self.session = aiohttp.ClientSession(connector=self._connector)
34-
self.graphql = GraphQLClient(connector=self._connector)
30+
self.session = session
31+
self._connector = connector
32+
self._resolver = resolver
33+
self.graphql = GraphQLClient(session=session)
3534

3635
@staticmethod
3736
async def on_ready() -> None:

bot/graphql.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def filter(self, record: logging.LogRecord) -> bool:
2323
class GraphQLClient:
2424
"""Friendo GraphQL API wrapper."""
2525

26-
def __init__(self, **session_kwargs):
27-
self.session = aiohttp.ClientSession(raise_for_status=True, **session_kwargs)
26+
def __init__(self, session: aiohttp.ClientSession):
27+
self.session = session
2828
self.token = None
2929
self.url = settings.FRIENDO_API_URL
3030
self.headers = None
@@ -44,10 +44,9 @@ async def refresh_token(self) -> None:
4444
)
4545
variables = {
4646
"username": settings.FRIENDO_API_USER,
47-
"password": settings.FRIENDO_API_PASS
47+
"password": settings.FRIENDO_API_PASS,
4848
}
4949
resp = await self._post(json={"query": query, "variables": variables})
50-
5150
self.token = resp["data"]["login"]["token"]
5251
self.headers = {
5352
"Authorization": f"Bearer {self.token}"
@@ -78,7 +77,5 @@ async def _post(self, **kwargs) -> dict:
7877
"""Make a GraphQL API POST call."""
7978
async with self.session.post(self.url, headers=self.headers, **kwargs) as resp:
8079
resp = await resp.json()
81-
8280
log.info(resp)
83-
8481
return resp

0 commit comments

Comments
 (0)