Skip to content

Commit b46838d

Browse files
committed
[DataStore] closes #7
Make Firebase calls asynchronous by running them in a ThreadPoolExecutor: firebase/firebase-admin-python#104 (comment) This touches a lot of code, because a new formatter was introduced as well.
1 parent 90615f6 commit b46838d

File tree

7 files changed

+246
-120
lines changed

7 files changed

+246
-120
lines changed

DataStore.py

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,93 @@
11
from abc import ABC, abstractmethod
2+
from functools import partial
3+
from firebase_admin import credentials, firestore
24

35
import firebase_admin
4-
from firebase_admin import credentials, firestore
6+
import concurrent
7+
import asyncio
58

69

710
class DataStore(ABC):
811
def __init__(self):
912
pass
1013

1114
@abstractmethod
12-
def set(self, collection, document, val):
15+
async def set(self, collection, document, val):
16+
pass
17+
18+
@abstractmethod
19+
async def set_get_id(self, collection, val):
1320
pass
1421

1522
@abstractmethod
16-
def update(self, collection, document, val):
23+
async def update(self, collection, document, val):
1724
pass
1825

1926
@abstractmethod
20-
def add(self, collection, val):
27+
async def add(self, collection, val):
2128
pass
2229

2330
@abstractmethod
24-
def get(self, collection, document):
31+
async def get(self, collection, document):
2532
pass
2633

2734
@abstractmethod
28-
def delete(self, collection, document):
35+
async def delete(self, collection, document):
2936
pass
3037

3138

3239
class FirebaseDataStore(DataStore):
33-
def __init__(self, key_file, db_name):
40+
def __init__(self, key_file, db_name, loop):
3441
super().__init__()
3542
cred = credentials.Certificate(key_file)
3643

37-
firebase_admin.initialize_app(cred, {
38-
'databaseURL': f'https://{db_name}.firebaseio.com'
39-
})
44+
firebase_admin.initialize_app(
45+
cred, {'databaseURL': f'https://{db_name}.firebaseio.com'})
4046

4147
self.db = firestore.client()
48+
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
49+
self.loop = loop
50+
51+
async def set(self, collection, document, val):
52+
ref = self._get_doc_ref(collection, document)
53+
await self.loop.run_in_executor(self.executor, partial(ref.set, val))
54+
55+
async def set_get_id(self, collection, val):
56+
ref = self._get_doc_ref(collection, None)
57+
await self.loop.run_in_executor(self.executor, partial(ref.set, val))
58+
return ref.id
59+
60+
async def update(self, collection, document, val):
61+
ref = self._get_doc_ref(collection, document)
62+
await self.loop.run_in_executor(self.executor,
63+
partial(ref.update, val))
64+
65+
async def add(self, collection, val):
66+
ref = self._get_collection(collection)
67+
await self.loop.run_in_executor(self.executor, partial(ref.add, val))
68+
69+
async def get(self, collection, document=None):
70+
if document is None:
71+
ref = self._get_collection(collection)
72+
return await self.loop.run_in_executor(self.executor, ref.stream)
73+
else:
74+
ref = self._get_doc_ref(collection, document)
75+
return await self.loop.run_in_executor(self.executor, ref.get)
4276

43-
def set(self, collection, document, val):
44-
self._get_doc_ref(collection, document).set(val)
45-
46-
def set_get_id(self, collection, val):
47-
doc = self._get_doc_ref(collection, None)
48-
doc.set(val)
49-
return doc.id
50-
51-
def update(self, collection, document, val):
52-
self._get_doc_ref(collection, document).update(val)
53-
54-
def add(self, collection, val):
55-
self._get_collection(collection).add(val)
56-
57-
def get(self, collection, document=None):
58-
return self._get_collection(collection).stream() if document is None else self._get_doc_ref(collection,
59-
document).get()
60-
61-
def delete(self, collection, document=None):
77+
async def delete(self, collection, document=None):
6278
if document is not None:
63-
self._get_doc_ref(collection, document).delete()
79+
ref = self._get_doc_ref(collection, document)
80+
await self.loop.run_in_executor(self.executor, ref.delete)
6481
else:
6582
# implement batching later
6683
docs = self._get_collection(collection).stream()
6784
for doc in docs:
68-
doc.reference.delete()
85+
await self.loop.run_in_executor(self.executor,
86+
doc.reference.delete)
6987

70-
def query(self, collection, *query):
71-
return self._get_collection(collection).where(*query).stream()
88+
async def query(self, collection, *query):
89+
ref = self._get_collection(collection).where(*query)
90+
return await self.loop.run_in_executor(self.executor, ref.stream)
7291

7392
def _get_doc_ref(self, collection, document):
7493
return self._get_collection(collection).document(document)
@@ -83,6 +102,10 @@ def _get_collection(self, collection):
83102
config = configparser.ConfigParser()
84103
config.read('conf.ini')
85104

86-
firebase_ds = FirebaseDataStore(
87-
config['firebase']['key_file'], config['firebase']['db_name'])
88-
firebase_ds.add('jobs', {'func': 'somefunc', 'time': 234903284, 'args': ['arg1', 'arg2']})
105+
firebase_ds = FirebaseDataStore(config['firebase']['key_file'],
106+
config['firebase']['db_name'])
107+
firebase_ds.add('jobs', {
108+
'func': 'somefunc',
109+
'time': 234903284,
110+
'args': ['arg1', 'arg2']
111+
})

botw-bot.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,29 @@
1010
config.read('conf.ini')
1111
client = commands.Bot(command_prefix=config['discord']['command_prefix'])
1212
client.config = config
13-
client.database = FirebaseDataStore(
14-
config['firebase']['key_file'], config['firebase']['db_name'])
13+
client.db = FirebaseDataStore(config['firebase']['key_file'],
14+
config['firebase']['db_name'], client.loop)
1515
logger = logging.getLogger('discord')
1616
logger.setLevel(logging.INFO)
17-
handler = logging.FileHandler(
18-
filename='botw-bot.log', encoding='utf-8', mode='w')
19-
handler.setFormatter(logging.Formatter(
20-
'%(asctime)s:%(levelname)s:%(name)s: %(message)s'))
17+
handler = logging.FileHandler(filename='botw-bot.log',
18+
encoding='utf-8',
19+
mode='w')
20+
handler.setFormatter(
21+
logging.Formatter('%(asctime)s:%(levelname)s:%(name)s: %(message)s'))
2122
logger.addHandler(handler)
2223

2324
INITIAL_EXTENSIONS = [
24-
'cogs.BiasOfTheWeek',
25-
'cogs.Utilities',
26-
'cogs.Scheduler',
27-
'cogs.EmojiUtils',
28-
'cogs.Tags',
29-
'jishaku'
25+
'cogs.BiasOfTheWeek', 'cogs.Utilities', 'cogs.Scheduler',
26+
'cogs.EmojiUtils', 'cogs.Tags', 'jishaku'
3027
]
3128

3229

3330
@client.event
3431
async def on_ready():
3532
await client.change_presence(activity=discord.Game('with Bini'))
36-
logger.info(f"Logged in as {client.user}. Whitelisted servers: {config.items('whitelisted_servers')}")
33+
logger.info(
34+
f"Logged in as {client.user}. Whitelisted servers: {config.items('whitelisted_servers')}"
35+
)
3736

3837
for ext in INITIAL_EXTENSIONS:
3938
ext_logger = logging.getLogger(ext)
@@ -54,7 +53,9 @@ async def globally_block_dms(ctx):
5453

5554
@client.check
5655
async def whitelisted_server(ctx):
57-
server_ids = [int(server) for key, server in config.items('whitelisted_servers')]
56+
server_ids = [
57+
int(server) for key, server in config.items('whitelisted_servers')
58+
]
5859
return ctx.guild.id in server_ids
5960

6061

cogs/BiasOfTheWeek.py

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from cogs.Scheduler import Job
1111

12-
1312
logger = logging.getLogger(__name__)
1413

1514

@@ -28,13 +27,11 @@ def __str__(self):
2827
def __eq__(self, other):
2928
if not isinstance(other, Idol):
3029
return NotImplemented
31-
return str.lower(self.group) == str.lower(other.group) and str.lower(self.name) == str.lower(other.name)
30+
return str.lower(self.group) == str.lower(other.group) and str.lower(
31+
self.name) == str.lower(other.name)
3232

3333
def to_dict(self):
34-
return {
35-
'group': self.group,
36-
'name': self.name
37-
}
34+
return {'group': self.group, 'name': self.name}
3835

3936
@staticmethod
4037
def from_dict(source):
@@ -45,47 +42,65 @@ class BiasOfTheWeek(commands.Cog):
4542
def __init__(self, bot):
4643
self.bot = bot
4744
self.nominations = {}
48-
self.nominations_collection = self.bot.config['biasoftheweek']['nominations_collection']
45+
self.nominations_collection = self.bot.config['biasoftheweek'][
46+
'nominations_collection']
47+
48+
if self.bot.loop.is_running():
49+
asyncio.create_task(self._ainit())
50+
else:
51+
self.bot.loop.run_until_complete(self._ainit())
52+
53+
async def _ainit(self):
54+
_nominations = await self.bot.db.get(self.nominations_collection)
4955

50-
_nominations = self.bot.database.get(self.nominations_collection)
5156
for nomination in _nominations:
52-
self.nominations[self.bot.get_user(int(nomination.id))] = Idol.from_dict(nomination.to_dict())
57+
self.nominations[self.bot.get_user(int(
58+
nomination.id))] = Idol.from_dict(nomination.to_dict())
5359

54-
logger.info(f'Initial nominations from database: {self.nominations}')
60+
logger.info(f'Initial nominations from db: {self.nominations}')
5561

5662
@staticmethod
5763
def reaction_check(reaction, user, author, prompt_msg):
5864
return user == author and str(reaction.emoji) in [CHECK_EMOJI, CROSS_EMOJI] and \
5965
reaction.message.id == prompt_msg.id
6066

6167
@commands.command()
62-
async def nominate(self, ctx, group: commands.clean_content, name: commands.clean_content):
68+
async def nominate(self, ctx, group: commands.clean_content,
69+
name: commands.clean_content):
6370
idol = Idol(group, name)
6471

6572
if idol in self.nominations.values():
66-
await ctx.send(f'**{idol}** has already been nominated. Please nominate someone else.')
73+
await ctx.send(
74+
f'**{idol}** has already been nominated. Please nominate someone else.'
75+
)
6776
elif ctx.author in self.nominations.keys():
6877
old_idol = self.nominations[ctx.author]
69-
prompt_msg = await ctx.send(f'Your current nomination is **{old_idol}**. Do you want to override it?')
78+
prompt_msg = await ctx.send(
79+
f'Your current nomination is **{old_idol}**. Do you want to override it?'
80+
)
7081
await prompt_msg.add_reaction(CHECK_EMOJI)
7182
await prompt_msg.add_reaction(CROSS_EMOJI)
7283
try:
73-
reaction, user = await self.bot.wait_for('reaction_add', timeout=60.0,
74-
check=lambda reaction, user: self.reaction_check(reaction,
75-
user,
76-
ctx.author,
77-
prompt_msg))
84+
reaction, user = await self.bot.wait_for(
85+
'reaction_add',
86+
timeout=60.0,
87+
check=lambda reaction, user: self.reaction_check(
88+
reaction, user, ctx.author, prompt_msg))
7889
except asyncio.TimeoutError:
7990
pass
8091
else:
8192
await prompt_msg.delete()
8293
if reaction.emoji == CHECK_EMOJI:
8394
self.nominations[ctx.author] = idol
84-
self.bot.database.set(self.nominations_collection, str(ctx.author.id), idol.to_dict())
85-
await ctx.send(f'{ctx.author} nominates **{idol}** instead of **{old_idol}**.')
95+
await self.bot.db.set(self.nominations_collection,
96+
str(ctx.author.id), idol.to_dict())
97+
await ctx.send(
98+
f'{ctx.author} nominates **{idol}** instead of **{old_idol}**.'
99+
)
86100
else:
87101
self.nominations[ctx.author] = idol
88-
self.bot.database.set(self.nominations_collection, str(ctx.author.id), idol.to_dict())
102+
await self.bot.db.set(self.nominations_collection,
103+
str(ctx.author.id), idol.to_dict())
89104
await ctx.send(f'{ctx.author} nominates **{idol}**.')
90105

91106
@nominate.error
@@ -96,7 +111,7 @@ async def nominate_error(self, ctx, error):
96111
@commands.has_permissions(administrator=True)
97112
async def clear_nominations(self, ctx):
98113
self.nominations = {}
99-
self.bot.database.delete(self.nominations_collection)
114+
await self.bot.db.delete(self.nominations_collection)
100115
await ctx.message.add_reaction(CHECK_EMOJI)
101116

102117
@commands.command()
@@ -110,23 +125,43 @@ async def nominations(self, ctx):
110125
else:
111126
await ctx.send('So far, no idols have been nominated.')
112127

128+
@commands.command()
129+
async def db_noms(self, ctx):
130+
embed = discord.Embed(title='Bias of the Week nominations')
131+
nominations = {}
132+
_nominations = await self.bot.db.get(self.nominations_collection)
133+
for nomination in _nominations:
134+
nominations[self.bot.get_user(int(
135+
nomination.id))] = Idol.from_dict(nomination.to_dict())
136+
137+
for key, value in nominations.items():
138+
embed.add_field(name=key, value=value)
139+
140+
await ctx.send(embed=embed)
141+
113142
@commands.command(name='pickwinner')
114143
@commands.has_permissions(administrator=True)
115-
async def pick_winner(self, ctx, silent: bool = False, fast_assign: bool = False):
144+
async def pick_winner(self,
145+
ctx,
146+
silent: bool = False,
147+
fast_assign: bool = False):
116148
member, pick = random.choice(list(self.nominations.items()))
117149

118150
# Assign BotW winner role on next wednesday at 00:00 UTC
119151
now = pendulum.now('Europe/London')
120-
assign_date = now.add(seconds=120) if fast_assign else now.next(
121-
pendulum.WEDNESDAY)
152+
assign_date = now.add(
153+
seconds=120) if fast_assign else now.next(pendulum.WEDNESDAY)
122154

123155
await ctx.send(
124156
f"""Bias of the Week ({now.week_of_year}-{now.year}): {member if silent else member.mention}\'s pick **{pick}**.
125-
You will be assigned the role *{self.bot.config['biasoftheweek']['winner_role_name']}* at {assign_date.to_cookie_string()}.""")
157+
You will be assigned the role *{self.bot.config['biasoftheweek']['winner_role_name']}* at {assign_date.to_cookie_string()}."""
158+
)
126159

127160
scheduler = self.bot.get_cog('Scheduler')
128161
if scheduler is not None:
129-
await scheduler.add_job(Job('assign_winner_role', [ctx.guild.id, member.id], assign_date.float_timestamp))
162+
await scheduler.add_job(
163+
Job('assign_winner_role', [ctx.guild.id, member.id],
164+
assign_date.float_timestamp))
130165

131166
@pick_winner.error
132167
async def pick_winner_error(self, ctx, error):

cogs/EmojiUtils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ class EmojiUtils(commands.Cog):
2020

2121
def __init__(self, bot):
2222
self.bot = bot
23-
self.bot.add_listener(self.on_guild_emojis_update, 'on_guild_emojis_update')
24-
self.emoji_channel = self.bot.get_channel(int(self.bot.config['emojiutils']['emoji_channel']))
23+
self.bot.add_listener(self.on_guild_emojis_update,
24+
'on_guild_emojis_update')
25+
self.emoji_channel = self.bot.get_channel(
26+
int(self.bot.config['emojiutils']['emoji_channel']))
2527

2628
@commands.group(name='emoji')
2729
@commands.has_permissions(administrator=True)
@@ -41,16 +43,22 @@ async def emoji_list_error(self, ctx, error):
4143
async def on_guild_emojis_update(self, guild, before, after):
4244
# delete old messages containing emoji
4345
# need to use Message.delete to be able to delete messages older than 14 days
44-
async for message in self.emoji_channel.history(limit=EmojiUtils.DELETE_LIMIT):
46+
async for message in self.emoji_channel.history(
47+
limit=EmojiUtils.DELETE_LIMIT):
4548
await message.delete()
4649

4750
# get emoji that were added in the last 10 minutes
48-
recent_emoji = [emoji for emoji in after if (
49-
time.time() - discord.utils.snowflake_time(emoji.id).timestamp()) < EmojiUtils.NEW_EMOTE_THRESHOLD]
51+
recent_emoji = [
52+
emoji for emoji in after
53+
if (time.time() -
54+
discord.utils.snowflake_time(emoji.id).timestamp()
55+
) < EmojiUtils.NEW_EMOTE_THRESHOLD
56+
]
5057

5158
emoji_sorted = sorted(after, key=lambda e: e.name)
5259
for emoji_chunk in chunker(emoji_sorted, EmojiUtils.SPLIT_MSG_AFTER):
5360
await self.emoji_channel.send(''.join(str(e) for e in emoji_chunk))
5461

5562
if len(recent_emoji) > 0:
56-
await self.emoji_channel.send(f"Newly added: {''.join(str(e) for e in recent_emoji)}")
63+
await self.emoji_channel.send(
64+
f"Newly added: {''.join(str(e) for e in recent_emoji)}")

0 commit comments

Comments
 (0)