Skip to content

Commit

Permalink
Use typing.Protocol instead of abc.ABCMeta
Browse files Browse the repository at this point in the history
  • Loading branch information
Gobot1234 authored Apr 4, 2021
1 parent fe54b3c commit 34ab772
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 100 deletions.
112 changes: 53 additions & 59 deletions discord/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
DEALINGS IN THE SOFTWARE.
"""

import abc
from __future__ import annotations

import sys
import copy
import asyncio
from typing import TYPE_CHECKING, Optional, Protocol, runtime_checkable

from .iterators import HistoryIterator
from .context_managers import Typing
Expand All @@ -39,13 +41,22 @@
from .voice_client import VoiceClient, VoiceProtocol
from . import utils

if TYPE_CHECKING:
from datetime import datetime

from .user import ClientUser


class _Undefined:
def __repr__(self):
return 'see-below'


_undefined = _Undefined()

class Snowflake(metaclass=abc.ABCMeta):

@runtime_checkable
class Snowflake(Protocol):
"""An ABC that details the common operations on a Discord model.
Almost all :ref:`Discord models <discord_api_models>` meet this
Expand All @@ -60,27 +71,16 @@ class Snowflake(metaclass=abc.ABCMeta):
The model's unique ID.
"""
__slots__ = ()
id: int

@property
@abc.abstractmethod
def created_at(self):
def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the model's creation time as a naive datetime in UTC."""
raise NotImplementedError

@classmethod
def __subclasshook__(cls, C):
if cls is Snowflake:
mro = C.__mro__
for attr in ('created_at', 'id'):
for base in mro:
if attr in base.__dict__:
break
else:
return NotImplemented
return True
return NotImplemented

class User(metaclass=abc.ABCMeta):
@runtime_checkable
class User(Snowflake, Protocol):
"""An ABC that details the common operations on a Discord user.
The following implement this ABC:
Expand All @@ -104,35 +104,24 @@ class User(metaclass=abc.ABCMeta):
"""
__slots__ = ()

name: str
discriminator: str
avatar: Optional[str]
bot: bool

@property
@abc.abstractmethod
def display_name(self):
def display_name(self) -> str:
""":class:`str`: Returns the user's display name."""
raise NotImplementedError

@property
@abc.abstractmethod
def mention(self):
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the given user."""
raise NotImplementedError

@classmethod
def __subclasshook__(cls, C):
if cls is User:
if Snowflake.__subclasshook__(C) is NotImplemented:
return NotImplemented

mro = C.__mro__
for attr in ('display_name', 'mention', 'name', 'avatar', 'discriminator', 'bot'):
for base in mro:
if attr in base.__dict__:
break
else:
return NotImplemented
return True
return NotImplemented

class PrivateChannel(metaclass=abc.ABCMeta):
@runtime_checkable
class PrivateChannel(Snowflake, Protocol):
"""An ABC that details the common operations on a private Discord channel.
The following implement this ABC:
Expand All @@ -149,18 +138,8 @@ class PrivateChannel(metaclass=abc.ABCMeta):
"""
__slots__ = ()

@classmethod
def __subclasshook__(cls, C):
if cls is PrivateChannel:
if Snowflake.__subclasshook__(C) is NotImplemented:
return NotImplemented
me: ClientUser

mro = C.__mro__
for base in mro:
if 'me' in base.__dict__:
return True
return NotImplemented
return NotImplemented

class _Overwrites:
__slots__ = ('id', 'allow', 'deny', 'type')
Expand All @@ -179,7 +158,8 @@ def _asdict(self):
'type': self.type,
}

class GuildChannel:

class GuildChannel(Protocol):
"""An ABC that details the common operations on a Discord guild channel.
The following implement this ABC:
Expand All @@ -190,6 +170,11 @@ class GuildChannel:
This ABC must also implement :class:`~discord.abc.Snowflake`.
Note
----
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
Attributes
-----------
name: :class:`str`
Expand Down Expand Up @@ -826,14 +811,13 @@ async def move(self, **kwargs):
lock_permissions = kwargs.get('sync_permissions', False)
reason = kwargs.get('reason')
for index, channel in enumerate(channels):
d = { 'id': channel.id, 'position': index }
d = {'id': channel.id, 'position': index}
if parent_id is not ... and channel.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)

await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason)


async def create_invite(self, *, reason=None, **fields):
"""|coro|
Expand Down Expand Up @@ -908,7 +892,8 @@ async def invites(self):

return result

class Messageable(metaclass=abc.ABCMeta):

class Messageable(Protocol):
"""An ABC that details the common operations on a model that can send messages.
The following implement this ABC:
Expand All @@ -919,11 +904,16 @@ class Messageable(metaclass=abc.ABCMeta):
- :class:`~discord.User`
- :class:`~discord.Member`
- :class:`~discord.ext.commands.Context`
Note
----
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
"""

__slots__ = ()

@abc.abstractmethod
async def _get_channel(self):
raise NotImplementedError

Expand Down Expand Up @@ -1060,8 +1050,8 @@ async def send(self, content=None, *, tts=False, embed=None, file=None,
f.close()
else:
data = await state.http.send_message(channel.id, content, tts=tts, embed=embed,
nonce=nonce, allowed_mentions=allowed_mentions,
message_reference=reference)
nonce=nonce, allowed_mentions=allowed_mentions,
message_reference=reference)

ret = state.create_message(channel=channel, data=data)
if delete_after is not None:
Expand Down Expand Up @@ -1213,21 +1203,25 @@ def history(self, *, limit=100, before=None, after=None, around=None, oldest_fir
"""
return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)

class Connectable(metaclass=abc.ABCMeta):

class Connectable(Protocol):
"""An ABC that details the common operations on a channel that can
connect to a voice server.
The following implement this ABC:
- :class:`~discord.VoiceChannel`
Note
----
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
"""
__slots__ = ()

@abc.abstractmethod
def _get_voice_client_key(self):
raise NotImplementedError

@abc.abstractmethod
def _get_voice_state_pair(self):
raise NotImplementedError

Expand Down Expand Up @@ -1286,6 +1280,6 @@ async def connect(self, *, timeout=60.0, reconnect=True, cls=VoiceClient):
except Exception:
# we don't care if disconnect failed because connection failed
pass
raise # re-raise
raise # re-raise

return voice
Loading

0 comments on commit 34ab772

Please sign in to comment.