diff --git a/discord/abc.py b/discord/abc.py index c3735f6c29b1..4930ae31d950 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -22,10 +22,12 @@ DEALINGS IN THE SOFTWARE. """ -import abc +from __future__ import annotations + import sys import copy import asyncio +from typing import TYPE_CHECKING, Optional, Protocol, runtime_checkable from .iterators import HistoryIterator from .context_managers import Typing @@ -39,13 +41,22 @@ from .voice_client import VoiceClient, VoiceProtocol from . import utils +if TYPE_CHECKING: + from datetime import datetime + + from .user import ClientUser + + class _Undefined: def __repr__(self): return 'see-below' + _undefined = _Undefined() -class Snowflake(metaclass=abc.ABCMeta): + +@runtime_checkable +class Snowflake(Protocol): """An ABC that details the common operations on a Discord model. Almost all :ref:`Discord models ` meet this @@ -60,27 +71,16 @@ class Snowflake(metaclass=abc.ABCMeta): The model's unique ID. """ __slots__ = () + id: int @property - @abc.abstractmethod - def created_at(self): + def created_at(self) -> datetime: """:class:`datetime.datetime`: Returns the model's creation time as a naive datetime in UTC.""" raise NotImplementedError - @classmethod - def __subclasshook__(cls, C): - if cls is Snowflake: - mro = C.__mro__ - for attr in ('created_at', 'id'): - for base in mro: - if attr in base.__dict__: - break - else: - return NotImplemented - return True - return NotImplemented -class User(metaclass=abc.ABCMeta): +@runtime_checkable +class User(Snowflake, Protocol): """An ABC that details the common operations on a Discord user. The following implement this ABC: @@ -104,35 +104,24 @@ class User(metaclass=abc.ABCMeta): """ __slots__ = () + name: str + discriminator: str + avatar: Optional[str] + bot: bool + @property - @abc.abstractmethod - def display_name(self): + def display_name(self) -> str: """:class:`str`: Returns the user's display name.""" raise NotImplementedError @property - @abc.abstractmethod - def mention(self): + def mention(self) -> str: """:class:`str`: Returns a string that allows you to mention the given user.""" raise NotImplementedError - @classmethod - def __subclasshook__(cls, C): - if cls is User: - if Snowflake.__subclasshook__(C) is NotImplemented: - return NotImplemented - - mro = C.__mro__ - for attr in ('display_name', 'mention', 'name', 'avatar', 'discriminator', 'bot'): - for base in mro: - if attr in base.__dict__: - break - else: - return NotImplemented - return True - return NotImplemented -class PrivateChannel(metaclass=abc.ABCMeta): +@runtime_checkable +class PrivateChannel(Snowflake, Protocol): """An ABC that details the common operations on a private Discord channel. The following implement this ABC: @@ -149,18 +138,8 @@ class PrivateChannel(metaclass=abc.ABCMeta): """ __slots__ = () - @classmethod - def __subclasshook__(cls, C): - if cls is PrivateChannel: - if Snowflake.__subclasshook__(C) is NotImplemented: - return NotImplemented + me: ClientUser - mro = C.__mro__ - for base in mro: - if 'me' in base.__dict__: - return True - return NotImplemented - return NotImplemented class _Overwrites: __slots__ = ('id', 'allow', 'deny', 'type') @@ -179,7 +158,8 @@ def _asdict(self): 'type': self.type, } -class GuildChannel: + +class GuildChannel(Protocol): """An ABC that details the common operations on a Discord guild channel. The following implement this ABC: @@ -190,6 +170,11 @@ class GuildChannel: This ABC must also implement :class:`~discord.abc.Snowflake`. + Note + ---- + This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass` + checks. + Attributes ----------- name: :class:`str` @@ -826,14 +811,13 @@ async def move(self, **kwargs): lock_permissions = kwargs.get('sync_permissions', False) reason = kwargs.get('reason') for index, channel in enumerate(channels): - d = { 'id': channel.id, 'position': index } + d = {'id': channel.id, 'position': index} if parent_id is not ... and channel.id == self.id: d.update(parent_id=parent_id, lock_permissions=lock_permissions) payload.append(d) await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) - async def create_invite(self, *, reason=None, **fields): """|coro| @@ -908,7 +892,8 @@ async def invites(self): return result -class Messageable(metaclass=abc.ABCMeta): + +class Messageable(Protocol): """An ABC that details the common operations on a model that can send messages. The following implement this ABC: @@ -919,11 +904,16 @@ class Messageable(metaclass=abc.ABCMeta): - :class:`~discord.User` - :class:`~discord.Member` - :class:`~discord.ext.commands.Context` + + + Note + ---- + This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass` + checks. """ __slots__ = () - @abc.abstractmethod async def _get_channel(self): raise NotImplementedError @@ -1060,8 +1050,8 @@ async def send(self, content=None, *, tts=False, embed=None, file=None, f.close() else: data = await state.http.send_message(channel.id, content, tts=tts, embed=embed, - nonce=nonce, allowed_mentions=allowed_mentions, - message_reference=reference) + nonce=nonce, allowed_mentions=allowed_mentions, + message_reference=reference) ret = state.create_message(channel=channel, data=data) if delete_after is not None: @@ -1213,21 +1203,25 @@ def history(self, *, limit=100, before=None, after=None, around=None, oldest_fir """ return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first) -class Connectable(metaclass=abc.ABCMeta): + +class Connectable(Protocol): """An ABC that details the common operations on a channel that can connect to a voice server. The following implement this ABC: - :class:`~discord.VoiceChannel` + + Note + ---- + This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass` + checks. """ __slots__ = () - @abc.abstractmethod def _get_voice_client_key(self): raise NotImplementedError - @abc.abstractmethod def _get_voice_state_pair(self): raise NotImplementedError @@ -1286,6 +1280,6 @@ async def connect(self, *, timeout=60.0, reconnect=True, cls=VoiceClient): except Exception: # we don't care if disconnect failed because connection failed pass - raise # re-raise + raise # re-raise return voice diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 010acca90fc3..d5a30a83b3fa 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -22,14 +22,19 @@ DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + import re import inspect -import typing +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, Union, runtime_checkable import discord - from .errors import * +if TYPE_CHECKING: + from .context import Context + + __all__ = ( 'Converter', 'MemberConverter', @@ -54,6 +59,7 @@ 'Greedy', ) + def _get_from_guilds(bot, getter, argument): result = None for guild in bot.guilds: @@ -62,9 +68,13 @@ def _get_from_guilds(bot, getter, argument): return result return result + _utils_get = discord.utils.get +T = TypeVar("T") -class Converter: + +@runtime_checkable +class Converter(Protocol[T]): """The base class of custom converters that require the :class:`.Context` to be passed to be useful. @@ -75,7 +85,7 @@ class Converter: method to do its conversion logic. This method must be a :ref:`coroutine `. """ - async def convert(self, ctx, argument): + async def convert(self, ctx: Context, argument: str) -> T: """|coro| The method to override to do conversion logic. @@ -100,7 +110,7 @@ async def convert(self, ctx, argument): """ raise NotImplementedError('Derived classes need to implement this.') -class IDConverter(Converter): +class IDConverter(Converter[T]): def __init__(self): self._id_regex = re.compile(r'([0-9]{15,20})$') super().__init__() @@ -108,7 +118,7 @@ def __init__(self): def _get_id_match(self, argument): return self._id_regex.match(argument) -class MemberConverter(IDConverter): +class MemberConverter(IDConverter[discord.Member]): """Converts to a :class:`~discord.Member`. All lookups are via the local guild. If in a DM context, then the lookup @@ -194,7 +204,7 @@ async def convert(self, ctx, argument): return result -class UserConverter(IDConverter): +class UserConverter(IDConverter[discord.User]): """Converts to a :class:`~discord.User`. All lookups are via the global user cache. @@ -253,7 +263,7 @@ async def convert(self, ctx, argument): return result -class PartialMessageConverter(Converter): +class PartialMessageConverter(Converter[discord.PartialMessage], Generic[T]): """Converts to a :class:`discord.PartialMessage`. .. versionadded:: 1.7 @@ -284,7 +294,7 @@ async def convert(self, ctx, argument): raise ChannelNotFound(channel_id) return discord.PartialMessage(channel=channel, id=message_id) -class MessageConverter(PartialMessageConverter): +class MessageConverter(PartialMessageConverter[discord.Message]): """Converts to a :class:`discord.Message`. .. versionadded:: 1.1 @@ -313,7 +323,7 @@ async def convert(self, ctx, argument): except discord.Forbidden: raise ChannelNotReadable(channel) -class TextChannelConverter(IDConverter): +class TextChannelConverter(IDConverter[discord.TextChannel]): """Converts to a :class:`~discord.TextChannel`. All lookups are via the local guild. If in a DM context, then the lookup @@ -355,7 +365,7 @@ def check(c): return result -class VoiceChannelConverter(IDConverter): +class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): """Converts to a :class:`~discord.VoiceChannel`. All lookups are via the local guild. If in a DM context, then the lookup @@ -396,7 +406,7 @@ def check(c): return result -class StageChannelConverter(IDConverter): +class StageChannelConverter(IDConverter[discord.StageChannel]): """Converts to a :class:`~discord.StageChannel`. .. versionadded:: 1.7 @@ -436,7 +446,7 @@ def check(c): return result -class CategoryChannelConverter(IDConverter): +class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): """Converts to a :class:`~discord.CategoryChannel`. All lookups are via the local guild. If in a DM context, then the lookup @@ -478,7 +488,7 @@ def check(c): return result -class StoreChannelConverter(IDConverter): +class StoreChannelConverter(IDConverter[discord.StoreChannel]): """Converts to a :class:`~discord.StoreChannel`. All lookups are via the local guild. If in a DM context, then the lookup @@ -519,7 +529,7 @@ def check(c): return result -class ColourConverter(Converter): +class ColourConverter(Converter[discord.Colour]): """Converts to a :class:`~discord.Colour`. .. versionchanged:: 1.5 @@ -603,7 +613,7 @@ async def convert(self, ctx, argument): ColorConverter = ColourConverter -class RoleConverter(IDConverter): +class RoleConverter(IDConverter[discord.Role]): """Converts to a :class:`~discord.Role`. All lookups are via the local guild. If in a DM context, then the lookup @@ -633,12 +643,12 @@ async def convert(self, ctx, argument): raise RoleNotFound(argument) return result -class GameConverter(Converter): +class GameConverter(Converter[discord.Game]): """Converts to :class:`~discord.Game`.""" async def convert(self, ctx, argument): return discord.Game(name=argument) -class InviteConverter(Converter): +class InviteConverter(Converter[discord.Invite]): """Converts to a :class:`~discord.Invite`. This is done via an HTTP request using :meth:`.Bot.fetch_invite`. @@ -653,7 +663,7 @@ async def convert(self, ctx, argument): except Exception as exc: raise BadInviteArgument() from exc -class GuildConverter(IDConverter): +class GuildConverter(IDConverter[discord.Guild]): """Converts to a :class:`~discord.Guild`. The lookup strategy is as follows (in order): @@ -679,7 +689,7 @@ async def convert(self, ctx, argument): raise GuildNotFound(argument) return result -class EmojiConverter(IDConverter): +class EmojiConverter(IDConverter[discord.Emoji]): """Converts to a :class:`~discord.Emoji`. All lookups are done for the local guild first, if available. If that lookup @@ -722,7 +732,7 @@ async def convert(self, ctx, argument): return result -class PartialEmojiConverter(Converter): +class PartialEmojiConverter(Converter[discord.PartialEmoji]): """Converts to a :class:`~discord.PartialEmoji`. This is done by extracting the animated flag, name and ID from the emoji. @@ -743,7 +753,7 @@ async def convert(self, ctx, argument): raise PartialEmojiConversionFailure(argument) -class clean_content(Converter): +class clean_content(Converter[str]): """Converts the argument to mention scrubbed version of said content. @@ -775,7 +785,7 @@ async def convert(self, ctx, argument): if self.fix_channel_mentions and ctx.guild: def resolve_channel(id, *, _get=ctx.guild.get_channel): ch = _get(id) - return (f'<#{id}>'), ('#' + ch.name if ch else '#deleted-channel') + return f'<#{id}>', ('#' + ch.name if ch else '#deleted-channel') transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions) @@ -842,7 +852,7 @@ def __getitem__(self, params): if converter is str or converter is type(None) or converter is _Greedy: raise TypeError(f'Greedy[{converter.__name__}] is invalid.') - if getattr(converter, '__origin__', None) is typing.Union and type(None) in converter.__args__: + if getattr(converter, '__origin__', None) is Union and type(None) in converter.__args__: raise TypeError(f'Greedy[{converter!r}] is invalid.') return self.__class__(converter=converter) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index d3badff474e2..ec2e7deb7c7d 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -448,11 +448,6 @@ async def _actual_conversion(self, ctx, converter, argument, param): instance = converter() ret = await instance.convert(ctx, argument) return ret - else: - method = getattr(converter, 'convert', None) - if method is not None and inspect.ismethod(method): - ret = await method(ctx, argument) - return ret elif isinstance(converter, converters.Converter): ret = await converter.convert(ctx, argument) return ret diff --git a/docs/api.rst b/docs/api.rst index 4a19c2c74147..5ef63f58f2f8 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -2499,20 +2499,19 @@ interface, :meth:`WebhookAdapter.request`. Abstract Base Classes ----------------------- -An :term:`py:abstract base class` (also known as an ``abc``) is a class that models can inherit -to get their behaviour. The Python implementation of an :doc:`abc ` is -slightly different in that you can register them at run-time. **Abstract base classes cannot be instantiated**. -They are mainly there for usage with :func:`py:isinstance` and :func:`py:issubclass`\. +An :term:`abstract base class` (also known as an ``abc``) is a class that models can inherit +to get their behaviour. **Abstract base classes should not be instantiated**. +They are mainly there for usage with :func:`isinstance` and :func:`issubclass`\. -This library has a module related to abstract base classes, some of which are actually from the :doc:`abc ` standard -module, others which are not. +This library has a module related to abstract base classes, in which all the ABCs are subclasses of +:class:`typing.Protocol`. Snowflake ~~~~~~~~~~ .. attributetable:: discord.abc.Snowflake -.. autoclass:: discord.abc.Snowflake +.. autoclass:: discord.abc.Snowflake() :members: User @@ -2520,7 +2519,7 @@ User .. attributetable:: discord.abc.User -.. autoclass:: discord.abc.User +.. autoclass:: discord.abc.User() :members: PrivateChannel @@ -2528,7 +2527,7 @@ PrivateChannel .. attributetable:: discord.abc.PrivateChannel -.. autoclass:: discord.abc.PrivateChannel +.. autoclass:: discord.abc.PrivateChannel() :members: GuildChannel @@ -2536,7 +2535,7 @@ GuildChannel .. attributetable:: discord.abc.GuildChannel -.. autoclass:: discord.abc.GuildChannel +.. autoclass:: discord.abc.GuildChannel() :members: Messageable @@ -2544,7 +2543,7 @@ Messageable .. attributetable:: discord.abc.Messageable -.. autoclass:: discord.abc.Messageable +.. autoclass:: discord.abc.Messageable() :members: :exclude-members: history, typing @@ -2559,7 +2558,7 @@ Connectable .. attributetable:: discord.abc.Connectable -.. autoclass:: discord.abc.Connectable +.. autoclass:: discord.abc.Connectable() .. _discord_api_models: