From caf8b92496c6e5450ec064dac7a5a483ed73702c Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Fri, 9 Aug 2024 17:49:20 +0200 Subject: [PATCH 01/14] feat: add `Invite.type` (#1142) Co-authored-by: Victor <67214928+Victorsitou@users.noreply.github.com> --- changelog/1142.bugfix.rst | 1 + changelog/1142.feature.rst | 1 + disnake/audit_logs.py | 9 +++++++-- disnake/enums.py | 7 +++++++ disnake/guild.py | 11 +++++++++-- disnake/invite.py | 38 ++++++++++++++++++++++---------------- disnake/types/gateway.py | 3 ++- disnake/types/invite.py | 4 +++- docs/api/invites.rst | 21 +++++++++++++++++++++ 9 files changed, 73 insertions(+), 22 deletions(-) create mode 100644 changelog/1142.bugfix.rst create mode 100644 changelog/1142.feature.rst diff --git a/changelog/1142.bugfix.rst b/changelog/1142.bugfix.rst new file mode 100644 index 0000000000..4f10261ea3 --- /dev/null +++ b/changelog/1142.bugfix.rst @@ -0,0 +1 @@ +Support fetching invites with ``null`` channel (e.g. friend invites). diff --git a/changelog/1142.feature.rst b/changelog/1142.feature.rst new file mode 100644 index 0000000000..44d87009f1 --- /dev/null +++ b/changelog/1142.feature.rst @@ -0,0 +1 @@ +Add :attr:`Invite.type`. diff --git a/disnake/audit_logs.py b/disnake/audit_logs.py index c5ec6caf79..cc2948f9a3 100644 --- a/disnake/audit_logs.py +++ b/disnake/audit_logs.py @@ -64,6 +64,7 @@ DefaultReaction as DefaultReactionPayload, PermissionOverwrite as PermissionOverwritePayload, ) + from .types.invite import Invite as InvitePayload from .types.role import Role as RolePayload from .types.snowflake import Snowflake from .types.threads import ForumTag as ForumTagPayload @@ -799,15 +800,19 @@ def _convert_target_invite(self, target_id: int) -> Invite: # so figure out which change has the full invite data changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after - fake_payload = { + fake_payload: InvitePayload = { "max_age": changeset.max_age, "max_uses": changeset.max_uses, "code": changeset.code, "temporary": changeset.temporary, "uses": changeset.uses, + "type": 0, + "channel": None, } - obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore + obj = Invite( + state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel + ) try: obj.inviter = changeset.inviter except AttributeError: diff --git a/disnake/enums.py b/disnake/enums.py index 0597343ff7..56b06ca5a2 100644 --- a/disnake/enums.py +++ b/disnake/enums.py @@ -41,6 +41,7 @@ "ExpireBehavior", "StickerType", "StickerFormatType", + "InviteType", "InviteTarget", "VideoQualityMode", "ComponentType", @@ -606,6 +607,12 @@ def file_extension(self) -> str: } +class InviteType(Enum): + guild = 0 + group_dm = 1 + friend = 2 + + class InviteTarget(Enum): unknown = 0 stream = 1 diff --git a/disnake/guild.py b/disnake/guild.py index 301d926876..97ea1e80ac 100644 --- a/disnake/guild.py +++ b/disnake/guild.py @@ -3185,7 +3185,10 @@ async def invites(self) -> List[Invite]: data = await self._state.http.invites_from(self.id) result = [] for invite in data: - channel = self.get_channel(int(invite["channel"]["id"])) + if channel_data := invite.get("channel"): + channel = self.get_channel(int(channel_data["id"])) + else: + channel = None result.append(Invite(state=self._state, data=invite, guild=self, channel=channel)) return result @@ -4130,11 +4133,15 @@ async def vanity_invite(self, *, use_cached: bool = False) -> Optional[Invite]: # reliable or a thing anymore data = await self._state.http.get_invite(payload["code"]) - channel = self.get_channel(int(data["channel"]["id"])) + if channel_data := data.get("channel"): + channel = self.get_channel(int(channel_data["id"])) + else: + channel = None payload["temporary"] = False payload["max_uses"] = 0 payload["max_age"] = 0 payload["uses"] = payload.get("uses", 0) + payload["type"] = 0 return Invite(state=self._state, data=payload, guild=self, channel=channel) # TODO: use MISSING when async iterators get refactored diff --git a/disnake/invite.py b/disnake/invite.py index a936c832b1..545159dfb8 100644 --- a/disnake/invite.py +++ b/disnake/invite.py @@ -6,7 +6,7 @@ from .appinfo import PartialAppInfo from .asset import Asset -from .enums import ChannelType, InviteTarget, NSFWLevel, VerificationLevel, try_enum +from .enums import ChannelType, InviteTarget, InviteType, NSFWLevel, VerificationLevel, try_enum from .guild_scheduled_event import GuildScheduledEvent from .mixins import Hashable from .object import Object @@ -307,8 +307,13 @@ class Invite(Hashable): ---------- code: :class:`str` The URL fragment used for the invite. + type: :class:`InviteType` + The type of the invite. + + .. versionadded:: 2.10 + guild: Optional[Union[:class:`Guild`, :class:`Object`, :class:`PartialInviteGuild`]] - The guild the invite is for. Can be ``None`` if it's from a group direct message. + The guild the invite is for. Can be ``None`` if it's not a guild invite (see :attr:`type`). max_age: Optional[:class:`int`] How long before the invite expires in seconds. A value of ``0`` indicates that it doesn't expire. @@ -382,6 +387,7 @@ class Invite(Hashable): __slots__ = ( "max_age", "code", + "type", "guild", "created_at", "uses", @@ -412,6 +418,7 @@ def __init__( ) -> None: self._state: ConnectionState = state self.code: str = data["code"] + self.type: InviteType = try_enum(InviteType, data.get("type", 0)) self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get("guild"), guild) self.max_age: Optional[int] = data.get("max_age") @@ -481,15 +488,12 @@ def from_incomplete(cls, *, state: ConnectionState, data: InvitePayload) -> Self # If it's not cached, then it has to be a partial guild guild = PartialInviteGuild(state, guild_data, guild_id) - # todo: this is no longer true - # As far as I know, invites always need a channel - # So this should never raise. - channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel( - data=data["channel"], state=state - ) - if guild is not None and not isinstance(guild, PartialInviteGuild): - # Upgrade the partial data if applicable - channel = guild.get_channel(channel.id) or channel + channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None + if channel_data := data.get("channel"): + channel = PartialInviteChannel(data=channel_data, state=state) + if guild is not None and not isinstance(guild, PartialInviteGuild): + # Upgrade the partial data if applicable + channel = guild.get_channel(channel.id) or channel return cls(state=state, data=data, guild=guild, channel=channel) @@ -543,11 +547,13 @@ def __str__(self) -> str: return self.url def __repr__(self) -> str: - return ( - f"" - ) + s = f" int: return hash(self.code) diff --git a/disnake/types/gateway.py b/disnake/types/gateway.py index e2494848b7..9e81523d29 100644 --- a/disnake/types/gateway.py +++ b/disnake/types/gateway.py @@ -17,7 +17,7 @@ from .guild_scheduled_event import GuildScheduledEvent from .integration import BaseIntegration from .interactions import BaseInteraction, GuildApplicationCommandPermissions -from .invite import InviteTargetType +from .invite import InviteTargetType, InviteType from .member import MemberWithUser from .message import Message from .role import Role @@ -348,6 +348,7 @@ class InviteCreateEvent(TypedDict): target_user: NotRequired[User] target_application: NotRequired[PartialAppInfo] temporary: bool + type: InviteType uses: int # always 0 diff --git a/disnake/types/invite.py b/disnake/types/invite.py index 93e573cfa5..b1f4ac4b63 100644 --- a/disnake/types/invite.py +++ b/disnake/types/invite.py @@ -12,6 +12,7 @@ from .guild_scheduled_event import GuildScheduledEvent from .user import PartialUser +InviteType = Literal[0, 1, 2] InviteTargetType = Literal[1, 2] @@ -30,8 +31,9 @@ class _InviteMetadata(TypedDict, total=False): class Invite(_InviteMetadata): code: str + type: InviteType guild: NotRequired[InviteGuild] - channel: InviteChannel + channel: Optional[InviteChannel] inviter: NotRequired[PartialUser] target_type: NotRequired[InviteTargetType] target_user: NotRequired[PartialUser] diff --git a/docs/api/invites.rst b/docs/api/invites.rst index 729f91272a..aa58074c2f 100644 --- a/docs/api/invites.rst +++ b/docs/api/invites.rst @@ -37,6 +37,27 @@ PartialInviteChannel Enumerations ------------ +InviteType +~~~~~~~~~~ + +.. class:: InviteType + + Represents the type of an invite. + + .. versionadded:: 2.10 + + .. attribute:: guild + + Represents an invite to a guild. + + .. attribute:: group_dm + + Represents an invite to a group channel. + + .. attribute:: friend + + Represents a friend invite. + InviteTarget ~~~~~~~~~~~~ From 409140f65f02a7a67696fcf050552272ca42cd20 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Sat, 10 Aug 2024 15:10:56 +0200 Subject: [PATCH 02/14] docs(channel): clarify `ForumChannel.create_thread` return type (#1215) Co-authored-by: Victor <67214928+Victorsitou@users.noreply.github.com> --- disnake/channel.py | 10 ++++++---- docs/api/channels.rst | 18 ++++++++++++++++++ docs/api/guilds.rst | 4 ++-- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/disnake/channel.py b/disnake/channel.py index 8eaf8dd794..1f84e702b3 100644 --- a/disnake/channel.py +++ b/disnake/channel.py @@ -3494,7 +3494,7 @@ async def create_thread( ) -> ThreadWithMessage: """|coro| - Creates a thread in this channel. + Creates a thread (with an initial message) in this channel. You must have the :attr:`~Permissions.create_forum_threads` permission to do this. @@ -3507,6 +3507,10 @@ async def create_thread( .. versionchanged:: 2.6 The ``content`` parameter is no longer required. + .. note:: + Unlike :meth:`TextChannel.create_thread`, + this **returns a tuple** with both the created **thread and message**. + Parameters ---------- name: :class:`str` @@ -3583,10 +3587,8 @@ async def create_thread( Returns ------- - Tuple[:class:`Thread`, :class:`Message`] + :class:`ThreadWithMessage` A :class:`~typing.NamedTuple` with the newly created thread and the message sent in it. - - These values can also be accessed through the ``thread`` and ``message`` fields. """ from .message import Message from .webhook.async_ import handle_message_parameters_dict diff --git a/docs/api/channels.rst b/docs/api/channels.rst index faa709aa19..3086583883 100644 --- a/docs/api/channels.rst +++ b/docs/api/channels.rst @@ -182,6 +182,24 @@ ForumTag :members: :inherited-members: +ThreadWithMessage +~~~~~~~~~~~~~~~~~ + +.. class:: ThreadWithMessage + + A :class:`~typing.NamedTuple` which represents a thread and message returned from :meth:`ForumChannel.create_thread`. + + .. attribute:: thread + + The created thread. + + :type: :class:`Thread` + .. attribute:: message + + The initial message in the thread. + + :type: :class:`Message` + Enumerations ------------ diff --git a/docs/api/guilds.rst b/docs/api/guilds.rst index 7f59fb37c7..614ad3f355 100644 --- a/docs/api/guilds.rst +++ b/docs/api/guilds.rst @@ -56,7 +56,7 @@ BanEntry .. class:: BanEntry - A namedtuple which represents a ban returned from :meth:`~Guild.bans`. + A :class:`~typing.NamedTuple` which represents a ban returned from :meth:`~Guild.bans`. .. attribute:: reason @@ -74,7 +74,7 @@ BulkBanResult .. class:: BulkBanResult - A namedtuple which represents the successful and failed bans returned from :meth:`~Guild.bulk_ban`. + A :class:`~typing.NamedTuple` which represents the successful and failed bans returned from :meth:`~Guild.bulk_ban`. .. versionadded:: 2.10 From 47f28721ca5eb8c41d7a327342edab5d9a6fa79f Mon Sep 17 00:00:00 2001 From: Snipy7374 <100313469+Snipy7374@users.noreply.github.com> Date: Sun, 11 Aug 2024 14:41:53 +0200 Subject: [PATCH 03/14] feat(client): implement new sticker pack endpoint (#1222) Signed-off-by: Snipy7374 <100313469+Snipy7374@users.noreply.github.com> Co-authored-by: shiftinv <8530778+shiftinv@users.noreply.github.com> --- changelog/1221.feature.rst | 1 + disnake/client.py | 27 +++++++++++++++++++++++++++ disnake/http.py | 3 +++ 3 files changed, 31 insertions(+) create mode 100644 changelog/1221.feature.rst diff --git a/changelog/1221.feature.rst b/changelog/1221.feature.rst new file mode 100644 index 0000000000..cf252ab1dc --- /dev/null +++ b/changelog/1221.feature.rst @@ -0,0 +1 @@ +Add new :meth:`.Client.fetch_sticker_pack` method. diff --git a/disnake/client.py b/disnake/client.py index 156720b161..80b3d67c65 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -2512,6 +2512,33 @@ async def fetch_sticker(self, sticker_id: int, /) -> Union[StandardSticker, Guil cls, _ = _sticker_factory(data["type"]) # type: ignore return cls(state=self._connection, data=data) # type: ignore + async def fetch_sticker_pack(self, pack_id: int, /) -> StickerPack: + """|coro| + + Retrieves a :class:`.StickerPack` with the given ID. + + .. versionadded:: 2.10 + + Parameters + ---------- + pack_id: :class:`int` + The ID of the sticker pack to retrieve. + + Raises + ------ + HTTPException + Retrieving the sticker pack failed. + NotFound + Invalid sticker pack ID. + + Returns + ------- + :class:`.StickerPack` + The sticker pack you requested. + """ + data = await self.http.get_sticker_pack(pack_id) + return StickerPack(state=self._connection, data=data) + async def fetch_sticker_packs(self) -> List[StickerPack]: """|coro| diff --git a/disnake/http.py b/disnake/http.py index f10cd3fdd8..b8e9786f87 100644 --- a/disnake/http.py +++ b/disnake/http.py @@ -1562,6 +1562,9 @@ def estimate_pruned_members( def get_sticker(self, sticker_id: Snowflake) -> Response[sticker.Sticker]: return self.request(Route("GET", "/stickers/{sticker_id}", sticker_id=sticker_id)) + def get_sticker_pack(self, pack_id: Snowflake) -> Response[sticker.StickerPack]: + return self.request(Route("GET", "/sticker-packs/{pack_id}", pack_id=pack_id)) + def list_sticker_packs(self) -> Response[sticker.ListStickerPacks]: return self.request(Route("GET", "/sticker-packs")) From 050220b1a10d007c0522697359618ec911d06cf7 Mon Sep 17 00:00:00 2001 From: Snipy7374 <100313469+Snipy7374@users.noreply.github.com> Date: Tue, 13 Aug 2024 19:43:15 +0200 Subject: [PATCH 04/14] docs: clarify AuditLogAction.overwrite_create (#1227) --- changelog/1180.doc.rst | 1 + docs/api/audit_logs.rst | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 changelog/1180.doc.rst diff --git a/changelog/1180.doc.rst b/changelog/1180.doc.rst new file mode 100644 index 0000000000..1ef07bb612 --- /dev/null +++ b/changelog/1180.doc.rst @@ -0,0 +1 @@ +Adding some clarifying documentation around the type of :attr:`AuditLogEntry.extra` when the action is :attr:`~AuditLogAction.overwrite_create`. diff --git a/docs/api/audit_logs.rst b/docs/api/audit_logs.rst index f1a65e434c..29c03d52ab 100644 --- a/docs/api/audit_logs.rst +++ b/docs/api/audit_logs.rst @@ -861,9 +861,8 @@ AuditLogAction When this is the action, the type of :attr:`~AuditLogEntry.extra` is either a :class:`Role` or :class:`Member`. If the object is not found - then it is a :class:`Object` with an ID being filled, a name, and a - ``type`` attribute set to either ``'role'`` or ``'member'`` to help - dictate what type of ID it is. + then it is a :class:`Object` with an ID being filled, additionally if the object + refers to a role then the :class:`Object` has also a ``name`` attribute. Possible attributes for :class:`AuditLogDiff`: From 8387fa642bf56da2f9396fe76e96538812d72107 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:50:11 +0200 Subject: [PATCH 05/14] docs: fix `Intents.*_typing` description (#1231) --- disnake/flags.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/disnake/flags.py b/disnake/flags.py index ce9752f011..1b26493892 100644 --- a/disnake/flags.py +++ b/disnake/flags.py @@ -1443,7 +1443,7 @@ def reactions(self): @flag_value def guild_reactions(self): - """:class:`bool`: Whether guild message reaction related events are enabled. + """:class:`bool`: Whether guild reaction related events are enabled. See also :attr:`dm_reactions` for DMs or :attr:`reactions` for both. @@ -1499,7 +1499,7 @@ def typing(self): @flag_value def guild_typing(self): - """:class:`bool`: Whether guild and direct message typing related events are enabled. + """:class:`bool`: Whether guild typing related events are enabled. See also :attr:`dm_typing` for DMs or :attr:`typing` for both. @@ -1513,7 +1513,7 @@ def guild_typing(self): @flag_value def dm_typing(self): - """:class:`bool`: Whether guild and direct message typing related events are enabled. + """:class:`bool`: Whether direct message typing related events are enabled. See also :attr:`guild_typing` for guilds or :attr:`typing` for both. From 2ebbd6073d832705df95a02f26a33f4476763843 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:00:02 +0200 Subject: [PATCH 06/14] feat(app): add guild/user install count fields (#1220) --- changelog/1220.feature.rst | 1 + disnake/appinfo.py | 19 +++++++++++++++---- disnake/types/appinfo.py | 2 ++ 3 files changed, 18 insertions(+), 4 deletions(-) create mode 100644 changelog/1220.feature.rst diff --git a/changelog/1220.feature.rst b/changelog/1220.feature.rst new file mode 100644 index 0000000000..b321a969f2 --- /dev/null +++ b/changelog/1220.feature.rst @@ -0,0 +1 @@ +Add :attr:`AppInfo.approximate_guild_count` and :attr:`AppInfo.approximate_user_install_count`. diff --git a/disnake/appinfo.py b/disnake/appinfo.py index 2cc80b1956..15468f1eb5 100644 --- a/disnake/appinfo.py +++ b/disnake/appinfo.py @@ -98,8 +98,7 @@ class AppInfo: .. versionadded:: 1.3 guild_id: Optional[:class:`int`] - If this application is a game sold on Discord, - this field will be the guild to which it has been linked to. + The ID of the guild associated with the application, if any. .. versionadded:: 1.3 @@ -151,6 +150,15 @@ class AppInfo: in the guild role verification configuration. .. versionadded:: 2.8 + approximate_guild_count: :class:`int` + The approximate number of guilds the application is installed to. + + .. versionadded:: 2.10 + approximate_user_install_count: :class:`int` + The approximate number of users that have installed the application + (for user-installable apps). + + .. versionadded:: 2.10 """ __slots__ = ( @@ -177,6 +185,8 @@ class AppInfo: "install_params", "custom_install_url", "role_connections_verification_url", + "approximate_guild_count", + "approximate_user_install_count", ) def __init__(self, state: ConnectionState, data: AppInfoPayload) -> None: @@ -218,6 +228,8 @@ def __init__(self, state: ConnectionState, data: AppInfoPayload) -> None: self.role_connections_verification_url: Optional[str] = data.get( "role_connections_verification_url" ) + self.approximate_guild_count: int = data.get("approximate_guild_count", 0) + self.approximate_user_install_count: int = data.get("approximate_user_install_count", 0) def __repr__(self) -> str: return ( @@ -245,8 +257,7 @@ def cover_image(self) -> Optional[Asset]: @property def guild(self) -> Optional[Guild]: - """Optional[:class:`Guild`]: If this application is a game sold on Discord, - this field will be the guild to which it has been linked + """Optional[:class:`Guild`]: The guild associated with the application, if any. .. versionadded:: 1.3 """ diff --git a/disnake/types/appinfo.py b/disnake/types/appinfo.py index 35507dc82d..9df725a043 100644 --- a/disnake/types/appinfo.py +++ b/disnake/types/appinfo.py @@ -42,6 +42,8 @@ class AppInfo(BaseAppInfo): install_params: NotRequired[InstallParams] custom_install_url: NotRequired[str] role_connections_verification_url: NotRequired[str] + approximate_guild_count: NotRequired[int] + approximate_user_install_count: NotRequired[int] class PartialAppInfo(BaseAppInfo, total=False): From cebfb89fd6c9a1711a672b405c7337e5f838bf41 Mon Sep 17 00:00:00 2001 From: Snipy7374 <100313469+Snipy7374@users.noreply.github.com> Date: Tue, 20 Aug 2024 19:55:13 +0200 Subject: [PATCH 07/14] feat(message): add Attachment.title (#1219) Signed-off-by: Snipy7374 <100313469+Snipy7374@users.noreply.github.com> Co-authored-by: shiftinv <8530778+shiftinv@users.noreply.github.com> --- changelog/1218.feature.rst | 1 + disnake/message.py | 10 ++++++++++ disnake/types/message.py | 1 + 3 files changed, 12 insertions(+) create mode 100644 changelog/1218.feature.rst diff --git a/changelog/1218.feature.rst b/changelog/1218.feature.rst new file mode 100644 index 0000000000..2dc2e68c48 --- /dev/null +++ b/changelog/1218.feature.rst @@ -0,0 +1 @@ +Add new :attr:`Attachment.title` attribute. diff --git a/disnake/message.py b/disnake/message.py index 7957799de6..6e4e1aaa6c 100644 --- a/disnake/message.py +++ b/disnake/message.py @@ -253,6 +253,12 @@ class Attachment(Hashable): The attachment's width, in pixels. Only applicable to images and videos. filename: :class:`str` The attachment's filename. + title: Optional[:class:`str`] + The attachment title. If the filename contained special characters, + this will be set to the original filename, without filename extension. + + .. versionadded:: 2.10 + url: :class:`str` The attachment URL. If the message this attachment was attached to is deleted, then this will 404. @@ -294,6 +300,7 @@ class Attachment(Hashable): "height", "width", "filename", + "title", "url", "proxy_url", "_http", @@ -311,6 +318,7 @@ def __init__(self, *, data: AttachmentPayload, state: ConnectionState) -> None: self.height: Optional[int] = data.get("height") self.width: Optional[int] = data.get("width") self.filename: str = data["filename"] + self.title: Optional[str] = data.get("title") self.url: str = data["url"] self.proxy_url: str = data["proxy_url"] self._http = state.http @@ -510,6 +518,8 @@ def to_dict(self) -> AttachmentPayload: result["waveform"] = b64encode(self.waveform).decode("ascii") if self._flags: result["flags"] = self._flags + if self.title: + result["title"] = self.title return result diff --git a/disnake/types/message.py b/disnake/types/message.py index 29fdfda374..0b9bbed3c3 100644 --- a/disnake/types/message.py +++ b/disnake/types/message.py @@ -34,6 +34,7 @@ class Reaction(TypedDict): class Attachment(TypedDict): id: Snowflake filename: str + title: NotRequired[str] description: NotRequired[str] content_type: NotRequired[str] size: int From d4972ab1d5afa7fb5040fe0ba31650c261784c58 Mon Sep 17 00:00:00 2001 From: Snipy7374 <100313469+Snipy7374@users.noreply.github.com> Date: Sat, 24 Aug 2024 13:46:18 +0200 Subject: [PATCH 08/14] feat(polls): implement Polls (#1176) Co-authored-by: shiftinv <8530778+shiftinv@users.noreply.github.com> Co-authored-by: Victor <67214928+Victorsitou@users.noreply.github.com> --- changelog/1175.feature.rst | 6 + disnake/__init__.py | 1 + disnake/abc.py | 21 +- disnake/enums.py | 21 ++ disnake/ext/commands/base_core.py | 1 + disnake/ext/commands/core.py | 4 + disnake/flags.py | 58 ++++ disnake/http.py | 49 ++++ disnake/interactions/base.py | 29 ++ disnake/iterators.py | 63 +++++ disnake/message.py | 11 + disnake/permissions.py | 14 + disnake/poll.py | 423 ++++++++++++++++++++++++++++++ disnake/raw_models.py | 59 +++++ disnake/state.py | 34 +++ disnake/threads.py | 1 + disnake/types/gateway.py | 18 ++ disnake/types/message.py | 2 + disnake/types/poll.py | 72 +++++ disnake/webhook/async_.py | 15 ++ docs/api/events.rst | 60 +++++ docs/api/messages.rst | 49 ++++ docs/intents.rst | 5 +- 23 files changed, 1013 insertions(+), 3 deletions(-) create mode 100644 changelog/1175.feature.rst create mode 100644 disnake/poll.py create mode 100644 disnake/types/poll.py diff --git a/changelog/1175.feature.rst b/changelog/1175.feature.rst new file mode 100644 index 0000000000..78dd79b311 --- /dev/null +++ b/changelog/1175.feature.rst @@ -0,0 +1,6 @@ +Add the new poll discord API feature. This includes the following new classes and events: + +- New types: :class:`Poll`, :class:`PollAnswer`, :class:`PollMedia`, :class:`RawMessagePollVoteActionEvent` and :class:`PollLayoutType`. +- Edited :meth:`abc.Messageable.send`, :meth:`Webhook.send`, :meth:`ext.commands.Context.send` and :meth:`disnake.InteractionResponse.send_message` to be able to send polls. +- Edited :class:`Message` to store a new :attr:`Message.poll` attribute for polls. +- Edited :class:`Event` to contain the new :func:`on_message_poll_vote_add`, :func:`on_message_poll_vote_remove`, :func:`on_raw_message_poll_vote_add` and :func:`on_raw_message_poll_vote_remove`. diff --git a/disnake/__init__.py b/disnake/__init__.py index 73b0c56349..e0af7f3354 100644 --- a/disnake/__init__.py +++ b/disnake/__init__.py @@ -57,6 +57,7 @@ from .partial_emoji import * from .permissions import * from .player import * +from .poll import * from .raw_models import * from .reaction import * from .role import * diff --git a/disnake/abc.py b/disnake/abc.py index a3c78f8644..b5b0f5509e 100644 --- a/disnake/abc.py +++ b/disnake/abc.py @@ -74,6 +74,7 @@ from .iterators import HistoryIterator from .member import Member from .message import Message, MessageReference, PartialMessage + from .poll import Poll from .state import ConnectionState from .threads import AnyThreadArchiveDuration, ForumTag from .types.channel import ( @@ -640,6 +641,7 @@ def _apply_implict_permissions(self, base: Permissions) -> None: if not base.send_messages: base.send_tts_messages = False base.send_voice_messages = False + base.send_polls = False base.mention_everyone = False base.embed_links = False base.attach_files = False @@ -887,6 +889,7 @@ async def set_permissions( request_to_speak: Optional[bool] = ..., send_messages: Optional[bool] = ..., send_messages_in_threads: Optional[bool] = ..., + send_polls: Optional[bool] = ..., send_tts_messages: Optional[bool] = ..., send_voice_messages: Optional[bool] = ..., speak: Optional[bool] = ..., @@ -1435,6 +1438,7 @@ async def send( mention_author: bool = ..., view: View = ..., components: Components[MessageUIComponent] = ..., + poll: Poll = ..., ) -> Message: ... @@ -1456,6 +1460,7 @@ async def send( mention_author: bool = ..., view: View = ..., components: Components[MessageUIComponent] = ..., + poll: Poll = ..., ) -> Message: ... @@ -1477,6 +1482,7 @@ async def send( mention_author: bool = ..., view: View = ..., components: Components[MessageUIComponent] = ..., + poll: Poll = ..., ) -> Message: ... @@ -1498,6 +1504,7 @@ async def send( mention_author: bool = ..., view: View = ..., components: Components[MessageUIComponent] = ..., + poll: Poll = ..., ) -> Message: ... @@ -1520,6 +1527,7 @@ async def send( mention_author: Optional[bool] = None, view: Optional[View] = None, components: Optional[Components[MessageUIComponent]] = None, + poll: Optional[Poll] = None, ): """|coro| @@ -1528,7 +1536,7 @@ async def send( The content must be a type that can convert to a string through ``str(content)``. At least one of ``content``, ``embed``/``embeds``, ``file``/``files``, - ``stickers``, ``components``, or ``view`` must be provided. + ``stickers``, ``components``, ``poll`` or ``view`` must be provided. To upload a single file, the ``file`` parameter should be used with a single :class:`.File` object. To upload multiple files, the ``files`` @@ -1624,6 +1632,11 @@ async def send( .. versionadded:: 2.9 + poll: :class:`.Poll` + The poll to send with the message. + + .. versionadded:: 2.10 + Raises ------ HTTPException @@ -1676,6 +1689,10 @@ async def send( if stickers is not None: stickers_payload = [sticker.id for sticker in stickers] + poll_payload = None + if poll: + poll_payload = poll._to_dict() + allowed_mentions_payload = None if allowed_mentions is None: allowed_mentions_payload = state.allowed_mentions and state.allowed_mentions.to_dict() @@ -1737,6 +1754,7 @@ async def send( message_reference=reference_payload, stickers=stickers_payload, components=components_payload, + poll=poll_payload, flags=flags_payload, ) finally: @@ -1753,6 +1771,7 @@ async def send( message_reference=reference_payload, stickers=stickers_payload, components=components_payload, + poll=poll_payload, flags=flags_payload, ) diff --git a/disnake/enums.py b/disnake/enums.py index 56b06ca5a2..6f81211156 100644 --- a/disnake/enums.py +++ b/disnake/enums.py @@ -71,6 +71,7 @@ "OnboardingPromptType", "SKUType", "EntitlementType", + "PollLayoutType", ) @@ -1215,6 +1216,14 @@ class Event(Enum): """Called when messages are bulk deleted. Represents the :func:`on_bulk_message_delete` event. """ + poll_vote_add = "poll_vote_add" + """Called when a vote is added on a `Poll`. + Represents the :func:`on_poll_vote_add` event. + """ + poll_vote_remove = "poll_vote_remove" + """Called when a vote is removed from a `Poll`. + Represents the :func:`on_poll_vote_remove` event. + """ raw_message_edit = "raw_message_edit" """Called when a message is edited regardless of the state of the internal message cache. Represents the :func:`on_raw_message_edit` event. @@ -1227,6 +1236,14 @@ class Event(Enum): """Called when a bulk delete is triggered regardless of the messages being in the internal message cache or not. Represents the :func:`on_raw_bulk_message_delete` event. """ + raw_poll_vote_add = "raw_poll_vote_add" + """Called when a vote is added on a `Poll` regardless of the internal message cache. + Represents the :func:`on_raw_poll_vote_add` event. + """ + raw_poll_vote_remove = "raw_poll_vote_remove" + """Called when a vote is removed from a `Poll` regardless of the internal message cache. + Represents the :func:`on_raw_poll_vote_remove` event. + """ reaction_add = "reaction_add" """Called when a message has a reaction added to it. Represents the :func:`on_reaction_add` event. @@ -1364,6 +1381,10 @@ class EntitlementType(Enum): application_subscription = 8 +class PollLayoutType(Enum): + default = 1 + + T = TypeVar("T") diff --git a/disnake/ext/commands/base_core.py b/disnake/ext/commands/base_core.py index 8f99381110..804d3717f6 100644 --- a/disnake/ext/commands/base_core.py +++ b/disnake/ext/commands/base_core.py @@ -671,6 +671,7 @@ def default_member_permissions( request_to_speak: bool = ..., send_messages: bool = ..., send_messages_in_threads: bool = ..., + send_polls: bool = ..., send_tts_messages: bool = ..., send_voice_messages: bool = ..., speak: bool = ..., diff --git a/disnake/ext/commands/core.py b/disnake/ext/commands/core.py index ffa74aaede..eb7d190b0e 100644 --- a/disnake/ext/commands/core.py +++ b/disnake/ext/commands/core.py @@ -2032,6 +2032,7 @@ def has_permissions( request_to_speak: bool = ..., send_messages: bool = ..., send_messages_in_threads: bool = ..., + send_polls: bool = ..., send_tts_messages: bool = ..., send_voice_messages: bool = ..., speak: bool = ..., @@ -2157,6 +2158,7 @@ def bot_has_permissions( request_to_speak: bool = ..., send_messages: bool = ..., send_messages_in_threads: bool = ..., + send_polls: bool = ..., send_tts_messages: bool = ..., send_voice_messages: bool = ..., speak: bool = ..., @@ -2260,6 +2262,7 @@ def has_guild_permissions( request_to_speak: bool = ..., send_messages: bool = ..., send_messages_in_threads: bool = ..., + send_polls: bool = ..., send_tts_messages: bool = ..., send_voice_messages: bool = ..., speak: bool = ..., @@ -2360,6 +2363,7 @@ def bot_has_guild_permissions( request_to_speak: bool = ..., send_messages: bool = ..., send_messages_in_threads: bool = ..., + send_polls: bool = ..., send_tts_messages: bool = ..., send_voice_messages: bool = ..., speak: bool = ..., diff --git a/disnake/flags.py b/disnake/flags.py index 1b26493892..406095a6d2 100644 --- a/disnake/flags.py +++ b/disnake/flags.py @@ -1028,11 +1028,13 @@ def __init__( automod_execution: bool = ..., bans: bool = ..., dm_messages: bool = ..., + dm_polls: bool = ..., dm_reactions: bool = ..., dm_typing: bool = ..., emojis: bool = ..., emojis_and_stickers: bool = ..., guild_messages: bool = ..., + guild_polls: bool = ..., guild_reactions: bool = ..., guild_scheduled_events: bool = ..., guild_typing: bool = ..., @@ -1043,6 +1045,7 @@ def __init__( message_content: bool = ..., messages: bool = ..., moderation: bool = ..., + polls: bool = ..., presences: bool = ..., reactions: bool = ..., typing: bool = ..., @@ -1598,6 +1601,61 @@ def automod(self): """ return (1 << 20) | (1 << 21) + @alias_flag_value + def polls(self): + """:class:`bool`: Whether guild and direct message polls related events are enabled. + + This is a shortcut to set or get both :attr:`guild_polls` and :attr:`dm_polls`. + + This corresponds to the following events: + + - :func:`on_poll_vote_add` (both guilds and DMs) + - :func:`on_poll_vote_remove` (both guilds and DMs) + - :func:`on_raw_poll_vote_add` (both guilds and DMs) + - :func:`on_raw_poll_vote_remove` (both guilds and DMs) + """ + return (1 << 24) | (1 << 25) + + @flag_value + def guild_polls(self): + """:class:`bool`: Whether guild polls related events are enabled. + + .. versionadded:: 2.10 + + This corresponds to the following events: + + - :func:`on_poll_vote_add` (only for guilds) + - :func:`on_poll_vote_remove` (only for guilds) + - :func:`on_raw_poll_vote_add` (only for guilds) + - :func:`on_raw_poll_vote_remove` (only for guilds) + + This also corresponds to the following attributes and classes in terms of cache: + + - :attr:`Message.poll` (only for guild messages) + - :class:`Poll` and all its attributes. + """ + return 1 << 24 + + @flag_value + def dm_polls(self): + """:class:`bool`: Whether direct message polls related events are enabled. + + .. versionadded:: 2.10 + + This corresponds to the following events: + + - :func:`on_poll_vote_add` (only for DMs) + - :func:`on_poll_vote_remove` (only for DMs) + - :func:`on_raw_poll_vote_add` (only for DMs) + - :func:`on_raw_poll_vote_remove` (only for DMs) + + This also corresponds to the following attributes and classes in terms of cache: + + - :attr:`Message.poll` (only for DM messages) + - :class:`Poll` and all its attributes. + """ + return 1 << 25 + class MemberCacheFlags(BaseFlags): """Controls the library's cache policy when it comes to members. diff --git a/disnake/http.py b/disnake/http.py index b8e9786f87..f0d157d671 100644 --- a/disnake/http.py +++ b/disnake/http.py @@ -70,6 +70,7 @@ member, message, onboarding, + poll, role, sku, sticker, @@ -528,6 +529,7 @@ def send_message( message_reference: Optional[message.MessageReference] = None, stickers: Optional[Sequence[Snowflake]] = None, components: Optional[Sequence[components.Component]] = None, + poll: Optional[poll.PollCreatePayload] = None, flags: Optional[int] = None, ) -> Response[message.Message]: r = Route("POST", "/channels/{channel_id}/messages", channel_id=channel_id) @@ -563,8 +565,50 @@ def send_message( if flags is not None: payload["flags"] = flags + if poll is not None: + payload["poll"] = poll + return self.request(r, json=payload) + def get_poll_answer_voters( + self, + channel_id: Snowflake, + message_id: Snowflake, + answer_id: int, + *, + after: Optional[Snowflake] = None, + limit: Optional[int] = None, + ) -> Response[poll.PollVoters]: + params: Dict[str, Any] = {} + + if after is not None: + params["after"] = after + if limit is not None: + params["limit"] = limit + + return self.request( + Route( + "GET", + "/channels/{channel_id}/polls/{message_id}/answers/{answer_id}", + channel_id=channel_id, + message_id=message_id, + answer_id=answer_id, + ), + params=params, + ) + + def expire_poll( + self, channel_id: Snowflake, message_id: Snowflake + ) -> Response[message.Message]: + return self.request( + Route( + "POST", + "/channels/{channel_id}/polls/{message_id}/expire", + channel_id=channel_id, + message_id=message_id, + ) + ) + def send_typing(self, channel_id: Snowflake) -> Response[None]: return self.request(Route("POST", "/channels/{channel_id}/typing", channel_id=channel_id)) @@ -582,6 +626,7 @@ def send_multipart_helper( message_reference: Optional[message.MessageReference] = None, stickers: Optional[Sequence[Snowflake]] = None, components: Optional[Sequence[components.Component]] = None, + poll: Optional[poll.PollCreatePayload] = None, flags: Optional[int] = None, ) -> Response[message.Message]: payload: Dict[str, Any] = {"tts": tts} @@ -603,6 +648,8 @@ def send_multipart_helper( payload["sticker_ids"] = stickers if flags is not None: payload["flags"] = flags + if poll: + payload["poll"] = poll multipart = to_multipart_with_attachments(payload, files) @@ -622,6 +669,7 @@ def send_files( message_reference: Optional[message.MessageReference] = None, stickers: Optional[Sequence[Snowflake]] = None, components: Optional[Sequence[components.Component]] = None, + poll: Optional[poll.PollCreatePayload] = None, flags: Optional[int] = None, ) -> Response[message.Message]: r = Route("POST", "/channels/{channel_id}/messages", channel_id=channel_id) @@ -637,6 +685,7 @@ def send_files( message_reference=message_reference, stickers=stickers, components=components, + poll=poll, flags=flags, ) diff --git a/disnake/interactions/base.py b/disnake/interactions/base.py index ce3b5cc89b..9543cabc66 100644 --- a/disnake/interactions/base.py +++ b/disnake/interactions/base.py @@ -74,6 +74,7 @@ from ..file import File from ..guild import GuildChannel, GuildMessageable from ..mentions import AllowedMentions + from ..poll import Poll from ..state import ConnectionState from ..threads import Thread from ..types.components import Modal as ModalPayload @@ -386,6 +387,7 @@ async def edit_original_response( attachments: Optional[List[Attachment]] = MISSING, view: Optional[View] = MISSING, components: Optional[Components[MessageUIComponent]] = MISSING, + poll: Poll = MISSING, suppress_embeds: bool = MISSING, flags: MessageFlags = MISSING, allowed_mentions: Optional[AllowedMentions] = None, @@ -450,6 +452,12 @@ async def edit_original_response( .. versionadded:: 2.4 + poll: :class:`Poll` + A poll. This can only be sent after a defer. If not used after a defer the + discord API ignore the field. + + .. versionadded:: 2.10 + allowed_mentions: :class:`AllowedMentions` Controls the mentions being processed in this message. See :meth:`.abc.Messageable.send` for more information. @@ -512,6 +520,7 @@ async def edit_original_response( embeds=embeds, view=view, components=components, + poll=poll, suppress_embeds=suppress_embeds, flags=flags, allowed_mentions=allowed_mentions, @@ -625,6 +634,7 @@ async def send( suppress_embeds: bool = MISSING, flags: MessageFlags = MISSING, delete_after: float = MISSING, + poll: Poll = MISSING, ) -> None: """|coro| @@ -701,6 +711,11 @@ async def send( .. versionchanged:: 2.7 Added support for ephemeral responses. + poll: :class:`Poll` + The poll to send with the message. + + .. versionadded:: 2.10 + Raises ------ HTTPException @@ -728,6 +743,7 @@ async def send( suppress_embeds=suppress_embeds, flags=flags, delete_after=delete_after, + poll=poll, ) @@ -902,6 +918,7 @@ async def send_message( suppress_embeds: bool = MISSING, flags: MessageFlags = MISSING, delete_after: float = MISSING, + poll: Poll = MISSING, ) -> None: """|coro| @@ -964,6 +981,12 @@ async def send_message( .. versionadded:: 2.9 + poll: :class:`Poll` + The poll to send with the message. + + .. versionadded:: 2.10 + + Raises ------ HTTPException @@ -1037,6 +1060,8 @@ async def send_message( if components is not MISSING: payload["components"] = components_to_dict(components) + if poll is not MISSING: + payload["poll"] = poll._to_dict() parent = self._parent adapter = async_context.get() @@ -1550,6 +1575,10 @@ class InteractionMessage(Message): A list of components in the message. guild: Optional[:class:`Guild`] The guild that the message belongs to, if applicable. + poll: Optional[:class:`Poll`] + The poll contained in this message. + + .. versionadded:: 2.10 """ __slots__ = () diff --git a/disnake/iterators.py b/disnake/iterators.py index 9e15f379b4..6d629066af 100644 --- a/disnake/iterators.py +++ b/disnake/iterators.py @@ -39,6 +39,7 @@ "MemberIterator", "GuildScheduledEventUserIterator", "EntitlementIterator", + "PollAnswerIterator", ) if TYPE_CHECKING: @@ -1140,3 +1141,65 @@ async def _after_strategy(self, retrieve: int) -> List[EntitlementPayload]: # endpoint returns items in ascending order when `after` is used self.after = Object(id=int(data[-1]["id"])) return data + + +class PollAnswerIterator(_AsyncIterator[Union["User", "Member"]]): + def __init__( + self, + message: Message, + answer_id: int, + *, + limit: Optional[int], + after: Optional[Snowflake] = None, + ) -> None: + self.channel_id: int = message.channel.id + self.message_id: int = message.id + self.answer_id: int = answer_id + self.guild: Optional[Guild] = message.guild + self.state: ConnectionState = message._state + + self.limit: Optional[int] = limit + self.after: Optional[Snowflake] = after + + self.getter = message._state.http.get_poll_answer_voters + self.users = asyncio.Queue() + + async def next(self) -> Union[User, Member]: + if self.users.empty(): + await self.fill_users() + + try: + return self.users.get_nowait() + except asyncio.QueueEmpty: + raise NoMoreItems from None + + def _get_retrieve(self) -> bool: + self.retrieve = min(self.limit, 100) if self.limit is not None else 100 + return self.retrieve > 0 + + async def fill_users(self) -> None: + if self._get_retrieve(): + after = self.after.id if self.after else None + data = ( + await self.getter( + channel_id=self.channel_id, + message_id=self.message_id, + answer_id=self.answer_id, + after=after, + limit=self.retrieve, + ) + )["users"] + + if len(data): + if self.limit is not None: + self.limit -= self.retrieve + self.after = Object(id=int(data[-1]["id"])) + + if len(data) < 100: + self.limit = 0 # terminate loop + + for element in data: + member = None + if not (self.guild is None or isinstance(self.guild, Object)): + member = self.guild.get_member(int(element["id"])) + await self.users.put(member or self.state.create_user(data=element)) diff --git a/disnake/message.py b/disnake/message.py index 6e4e1aaa6c..15fe00c5a9 100644 --- a/disnake/message.py +++ b/disnake/message.py @@ -34,6 +34,7 @@ from .member import Member from .mixins import Hashable from .partial_emoji import PartialEmoji +from .poll import Poll from .reaction import Reaction from .sticker import StickerItem from .threads import Thread @@ -906,6 +907,11 @@ class Message(Hashable): guild: Optional[:class:`Guild`] The guild that the message belongs to, if applicable. + + poll: Optional[:class:`Poll`] + The poll contained in this message. + + .. versionadded:: 2.10 """ __slots__ = ( @@ -941,6 +947,7 @@ class Message(Hashable): "stickers", "components", "guild", + "poll", "_edited_timestamp", "_role_subscription_data", ) @@ -1002,6 +1009,10 @@ def __init__( ) self.interaction: Optional[InteractionReference] = inter + self.poll: Optional[Poll] = None + if poll_data := data.get("poll"): + self.poll = Poll.from_dict(message=self, data=poll_data) + try: # if the channel doesn't have a guild attribute, we handle that self.guild = channel.guild # type: ignore diff --git a/disnake/permissions.py b/disnake/permissions.py index cf19761ae4..95a6792fe8 100644 --- a/disnake/permissions.py +++ b/disnake/permissions.py @@ -197,6 +197,7 @@ def __init__( request_to_speak: bool = ..., send_messages: bool = ..., send_messages_in_threads: bool = ..., + send_polls: bool = ..., send_tts_messages: bool = ..., send_voice_messages: bool = ..., speak: bool = ..., @@ -428,6 +429,7 @@ def text(cls) -> Self: read_message_history=True, send_tts_messages=True, send_voice_messages=True, + send_polls=True, ) @classmethod @@ -599,6 +601,7 @@ def update( request_to_speak: bool = ..., send_messages: bool = ..., send_messages_in_threads: bool = ..., + send_polls: bool = ..., send_tts_messages: bool = ..., send_voice_messages: bool = ..., speak: bool = ..., @@ -1058,6 +1061,14 @@ def send_voice_messages(self) -> int: """ return 1 << 46 + @flag_value + def send_polls(self) -> int: + """:class:`bool`: Returns ``True`` if a user can send polls. + + .. versionadded:: 2.10 + """ + return 1 << 49 + @flag_value def use_external_apps(self) -> int: """:class:`bool`: Returns ``True`` if a user's apps can send public responses. @@ -1175,6 +1186,7 @@ class PermissionOverwrite: request_to_speak: Optional[bool] send_messages: Optional[bool] send_messages_in_threads: Optional[bool] + send_polls: Optional[bool] send_tts_messages: Optional[bool] send_voice_messages: Optional[bool] speak: Optional[bool] @@ -1242,6 +1254,7 @@ def __init__( request_to_speak: Optional[bool] = ..., send_messages: Optional[bool] = ..., send_messages_in_threads: Optional[bool] = ..., + send_polls: Optional[bool] = ..., send_tts_messages: Optional[bool] = ..., send_voice_messages: Optional[bool] = ..., speak: Optional[bool] = ..., @@ -1376,6 +1389,7 @@ def update( request_to_speak: Optional[bool] = ..., send_messages: Optional[bool] = ..., send_messages_in_threads: Optional[bool] = ..., + send_polls: Optional[bool] = ..., send_tts_messages: Optional[bool] = ..., send_voice_messages: Optional[bool] = ..., speak: Optional[bool] = ..., diff --git a/disnake/poll.py b/disnake/poll.py new file mode 100644 index 0000000000..39f4140945 --- /dev/null +++ b/disnake/poll.py @@ -0,0 +1,423 @@ +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from datetime import timedelta +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +from . import utils +from .abc import Snowflake +from .emoji import Emoji, _EmojiTag +from .enums import PollLayoutType, try_enum +from .iterators import PollAnswerIterator +from .partial_emoji import PartialEmoji + +if TYPE_CHECKING: + from datetime import datetime + + from .message import Message + from .state import ConnectionState + from .types.poll import ( + Poll as PollPayload, + PollAnswer as PollAnswerPayload, + PollCreateAnswerPayload, + PollCreateMediaPayload, + PollCreatePayload, + PollMedia as PollMediaPayload, + ) + +__all__ = ( + "PollMedia", + "PollAnswer", + "Poll", +) + + +class PollMedia: + """Represents data of a poll's question/answers. + + .. versionadded:: 2.10 + + Parameters + ---------- + text: :class:`str` + The text of this media. + emoji: Optional[Union[:class:`Emoji`, :class:`PartialEmoji`, :class:`str`]] + The emoji of this media. + + Attributes + ---------- + text: Optional[:class:`str`] + The text of this media. + emoji: Optional[:class:`PartialEmoji`] + The emoji of this media. + """ + + __slots__ = ("text", "emoji") + + def __init__( + self, text: Optional[str], *, emoji: Optional[Union[Emoji, PartialEmoji, str]] = None + ) -> None: + if text is None and emoji is None: + raise ValueError("At least one of `text` or `emoji` must be not None") + + self.text = text + self.emoji: Optional[Union[Emoji, PartialEmoji]] = None + if isinstance(emoji, str): + self.emoji = PartialEmoji.from_str(emoji) + elif isinstance(emoji, _EmojiTag): + self.emoji = emoji + else: + if emoji is not None: + raise TypeError("Emoji must be None, a str, PartialEmoji, or Emoji instance.") + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} text={self.text!r} emoji={self.emoji!r}>" + + @classmethod + def from_dict(cls, state: ConnectionState, data: PollMediaPayload) -> PollMedia: + text = data.get("text") + + emoji = None + if emoji_data := data.get("emoji"): + emoji = state._get_emoji_from_data(emoji_data) + + return cls(text=text, emoji=emoji) + + def _to_dict(self) -> PollCreateMediaPayload: + payload: PollCreateMediaPayload = {} + if self.text: + payload["text"] = self.text + if self.emoji: + if self.emoji.id: + payload["emoji"] = {"id": self.emoji.id} + else: + payload["emoji"] = {"name": self.emoji.name} + return payload + + +class PollAnswer: + """Represents a poll answer from discord. + + .. versionadded:: 2.10 + + Parameters + ---------- + media: :class:`PollMedia` + The media object to set the text and/or emoji for this answer. + + Attributes + ---------- + id: Optional[:class:`int`] + The ID of this answer. This will be ``None`` only if this object was created manually + and did not originate from the API. + media: :class:`PollMedia` + The media fields of this answer. + poll: Optional[:class:`Poll`] + The poll associated with this answer. This will be ``None`` only if this object was created manually + and did not originate from the API. + vote_count: :class:`int` + The number of votes for this answer. + self_voted: :class:`bool` + Whether the current user voted for this answer. + """ + + __slots__ = ("id", "media", "poll", "vote_count", "self_voted") + + def __init__(self, media: PollMedia) -> None: + self.id: Optional[int] = None + self.poll: Optional[Poll] = None + self.media = media + self.vote_count: int = 0 + self.self_voted: bool = False + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} media={self.media!r}>" + + @classmethod + def from_dict(cls, state: ConnectionState, poll: Poll, data: PollAnswerPayload) -> PollAnswer: + answer = cls(PollMedia.from_dict(state, data["poll_media"])) + answer.id = int(data["answer_id"]) + answer.poll = poll + + return answer + + def _to_dict(self) -> PollCreateAnswerPayload: + return {"poll_media": self.media._to_dict()} + + def voters( + self, *, limit: Optional[int] = 100, after: Optional[Snowflake] = None + ) -> PollAnswerIterator: + """Returns an :class:`AsyncIterator` representing the users that have voted for this answer. + + The ``after`` parameter must represent a member and meet the :class:`abc.Snowflake` abc. + + .. note:: + + This method works only on PollAnswer(s) objects that originate from the API and not on the ones built manually. + + Parameters + ---------- + limit: Optional[:class:`int`] + The maximum number of results to return. + If ``None``, retrieves every user who voted for this answer. + Note, however, that this would make it a slow operation. + Defaults to ``100``. + after: Optional[:class:`abc.Snowflake`] + For pagination, votes are sorted by member. + + Raises + ------ + HTTPException + Getting the voters for this answer failed. + Forbidden + Tried to get the voters for this answer without the required permissions. + ValueError + You tried to invoke this method on an object that didn't originate from the API. + + Yields + ------ + Union[:class:`User`, :class:`Member`] + The member (if retrievable) or the user that has voted + for this answer. The case where it can be a :class:`Member` is + in a guild message context. Sometimes it can be a :class:`User` + if the member has left the guild. + """ + if not (self.id is not None and self.poll and self.poll.message): + raise ValueError( + "This object was manually built. To use this method, you need to use a poll object retrieved from the Discord API." + ) + + return PollAnswerIterator(self.poll.message, self.id, limit=limit, after=after) + + +class Poll: + """Represents a poll from Discord. + + .. versionadded:: 2.10 + + Parameters + ---------- + question: Union[:class:`str`, :class:`PollMedia`] + The question of the poll. Currently, emojis are not supported in poll questions. + answers: List[Union[:class:`str`, :class:`PollAnswer`]] + The answers for this poll, up to 10. + duration: :class:`datetime.timedelta` + The total duration of the poll, up to 32 days. Defaults to 1 day. + Note that this gets rounded down to the closest hour. + allow_multiselect: :class:`bool` + Whether users will be able to pick more than one answer. Defaults to ``False``. + layout_type: :class:`PollLayoutType` + The layout type of the poll. Defaults to :attr:`PollLayoutType.default`. + + Attributes + ---------- + message: Optional[:class:`Message`] + The message which contains this poll. This will be ``None`` only if this object was created manually + and did not originate from the API. + question: :class:`PollMedia` + The question of the poll. + duration: Optional[:class:`datetime.timedelta`] + The original duration for this poll. ``None`` if the poll is a non-expiring poll. + allow_multiselect: :class:`bool` + Whether users are able to pick more than one answer. + layout_type: :class:`PollLayoutType` + The type of the layout of the poll. + is_finalized: :class:`bool` + Whether the votes have been precisely counted. + """ + + __slots__ = ( + "message", + "question", + "_answers", + "duration", + "allow_multiselect", + "layout_type", + "is_finalized", + ) + + def __init__( + self, + question: Union[str, PollMedia], + *, + answers: List[Union[str, PollAnswer]], + duration: timedelta = timedelta(hours=24), + allow_multiselect: bool = False, + layout_type: PollLayoutType = PollLayoutType.default, + ) -> None: + self.message: Optional[Message] = None + + if isinstance(question, str): + self.question = PollMedia(question) + elif isinstance(question, PollMedia): + self.question: PollMedia = question + else: + raise TypeError( + f"Expected 'str' or 'PollMedia' for 'question', got {question.__class__.__name__!r}." + ) + + self._answers: Dict[int, PollAnswer] = {} + for i, answer in enumerate(answers, 1): + if isinstance(answer, PollAnswer): + self._answers[i] = answer + elif isinstance(answer, str): + self._answers[i] = PollAnswer(PollMedia(answer)) + else: + raise TypeError( + f"Expected 'List[str]' or 'List[PollAnswer]' for 'answers', got List[{answer.__class__.__name__!r}]." + ) + + self.duration: Optional[timedelta] = duration + self.allow_multiselect: bool = allow_multiselect + self.layout_type: PollLayoutType = layout_type + self.is_finalized: bool = False + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} question={self.question!r} answers={self.answers!r}>" + + @property + def answers(self) -> List[PollAnswer]: + """List[:class:`PollAnswer`]: The list of answers for this poll. + + See also :meth:`get_answer` to get specific answers by ID. + """ + return list(self._answers.values()) + + @property + def created_at(self) -> Optional[datetime]: + """Optional[:class:`datetime.datetime`]: When this poll was created. + + ``None`` if this poll does not originate from the discord API. + """ + if not self.message: + return + return utils.snowflake_time(self.message.id) + + @property + def expires_at(self) -> Optional[datetime]: + """Optional[:class:`datetime.datetime`]: The date when this poll will expire. + + ``None`` if this poll does not originate from the discord API or if this + poll is non-expiring. + """ + # non-expiring poll + if not self.duration: + return + + created_at = self.created_at + # manually built object + if not created_at: + return + return created_at + self.duration + + @property + def remaining_duration(self) -> Optional[timedelta]: + """Optional[:class:`datetime.timedelta`]: The remaining duration for this poll. + If this poll is finalized this property will arbitrarily return a + zero valued timedelta. + + ``None`` if this poll does not originate from the discord API. + """ + if self.is_finalized: + return timedelta(hours=0) + if not self.expires_at or not self.message: + return + + return self.expires_at - utils.utcnow() + + def get_answer(self, answer_id: int, /) -> Optional[PollAnswer]: + """Return the requested poll answer. + + Parameters + ---------- + answer_id: :class:`int` + The answer id. + + Returns + ------- + Optional[:class:`PollAnswer`] + The requested answer. + """ + return self._answers.get(answer_id) + + @classmethod + def from_dict( + cls, + message: Message, + data: PollPayload, + ) -> Poll: + state = message._state + poll = cls( + question=PollMedia.from_dict(state, data["question"]), + answers=[], + allow_multiselect=data["allow_multiselect"], + layout_type=try_enum(PollLayoutType, data["layout_type"]), + ) + for answer in data["answers"]: + answer_obj = PollAnswer.from_dict(state, poll, answer) + poll._answers[int(answer["answer_id"])] = answer_obj + + poll.message = message + if expiry := data["expiry"]: + poll.duration = utils.parse_time(expiry) - utils.snowflake_time(poll.message.id) + else: + # future support for non-expiring polls + # read the foot note https://discord.com/developers/docs/resources/poll#poll-object-poll-object-structure + poll.duration = None + + if results := data.get("results"): + poll.is_finalized = results["is_finalized"] + + for answer_count in results["answer_counts"]: + try: + answer = poll._answers[int(answer_count["id"])] + except KeyError: + # this should never happen + continue + answer.vote_count = answer_count["count"] + answer.self_voted = answer_count["me_voted"] + + return poll + + def _to_dict(self) -> PollCreatePayload: + payload: PollCreatePayload = { + "question": self.question._to_dict(), + "duration": (int(self.duration.total_seconds()) // 3600), # type: ignore + "allow_multiselect": self.allow_multiselect, + "layout_type": self.layout_type.value, + "answers": [answer._to_dict() for answer in self._answers.values()], + } + return payload + + async def expire(self) -> Message: + """|coro| + + Immediately ends a poll. + + .. note:: + + This method works only on Poll(s) objects that originate + from the API and not on the ones built manually. + + Raises + ------ + HTTPException + Expiring the poll failed. + Forbidden + Tried to expire a poll without the required permissions. + ValueError + You tried to invoke this method on an object that didn't originate from the API.``` + + Returns + ------- + :class:`Message` + The message which contains the expired `Poll`. + """ + if not self.message: + raise ValueError( + "This object was manually built. To use this method, you need to use a poll object retrieved from the Discord API." + ) + + data = await self.message._state.http.expire_poll(self.message.channel.id, self.message.id) + return self.message._state.create_message(channel=self.message.channel, data=data) diff --git a/disnake/raw_models.py b/disnake/raw_models.py index 8b7e25f43a..48b8dab56d 100644 --- a/disnake/raw_models.py +++ b/disnake/raw_models.py @@ -24,6 +24,8 @@ MessageReactionRemoveEmojiEvent, MessageReactionRemoveEvent, MessageUpdateEvent, + PollVoteAddEvent, + PollVoteRemoveEvent, PresenceUpdateEvent, ThreadDeleteEvent, TypingStartEvent, @@ -45,6 +47,7 @@ "RawTypingEvent", "RawGuildMemberRemoveEvent", "RawPresenceUpdateEvent", + "RawPollVoteActionEvent", ) @@ -147,6 +150,62 @@ def __init__(self, data: MessageUpdateEvent) -> None: self.guild_id: Optional[int] = None +PollEventType = Literal["POLL_VOTE_ADD", "POLL_VOTE_REMOVE"] + + +class RawPollVoteActionEvent(_RawReprMixin): + """Represents the event payload for :func:`on_raw_poll_vote_add` and + :func:`on_raw_poll_vote_remove` events. + + .. versionadded:: 2.10 + + Attributes + ---------- + message_id: :class:`int` + The message ID that got or lost a vote. + user_id: :class:`int` + The user ID who added the vote or whose vote was removed. + cached_member: Optional[:class:`Member`] + The member who added the vote. Available only when the guilds and members are cached. + channel_id: :class:`int` + The channel ID where the vote addition or removal took place. + guild_id: Optional[:class:`int`] + The guild ID where the vote addition or removal took place, if applicable. + answer_id: :class:`int` + The ID of the answer that was voted or unvoted. + event_type: :class:`str` + The event type that triggered this action. Can be + ``POLL_VOTE_ADD`` for vote addition or + ``POLL_VOTE_REMOVE`` for vote removal. + """ + + __slots__ = ( + "message_id", + "user_id", + "cached_member", + "channel_id", + "guild_id", + "event_type", + "answer_id", + ) + + def __init__( + self, + data: Union[PollVoteAddEvent, PollVoteRemoveEvent], + event_type: PollEventType, + ) -> None: + self.message_id: int = int(data["message_id"]) + self.user_id: int = int(data["user_id"]) + self.cached_member: Optional[Member] = None + self.channel_id: int = int(data["channel_id"]) + self.event_type = event_type + self.answer_id: int = int(data["answer_id"]) + try: + self.guild_id: Optional[int] = int(data["guild_id"]) + except KeyError: + self.guild_id: Optional[int] = None + + ReactionEventType = Literal["REACTION_ADD", "REACTION_REMOVE"] diff --git a/disnake/state.py b/disnake/state.py index f4885513d7..ab4e5a8d78 100644 --- a/disnake/state.py +++ b/disnake/state.py @@ -70,6 +70,7 @@ RawIntegrationDeleteEvent, RawMessageDeleteEvent, RawMessageUpdateEvent, + RawPollVoteActionEvent, RawPresenceUpdateEvent, RawReactionActionEvent, RawReactionClearEmojiEvent, @@ -930,6 +931,39 @@ def parse_message_reaction_remove_emoji( if reaction: self.dispatch("reaction_clear_emoji", reaction) + def _handle_poll_event( + self, raw: RawPollVoteActionEvent, event_type: Literal["add", "remove"] + ) -> None: + guild = self._get_guild(raw.guild_id) + answer = None + if guild is not None: + member = guild.get_member(raw.user_id) + message = self._get_message(raw.message_id) + if message is not None and message.poll is not None: + answer = message.poll.get_answer(raw.answer_id) + + if member is not None: + raw.cached_member = member + + if answer is not None: + if event_type == "add": + answer.vote_count += 1 + else: + answer.vote_count -= 1 + + self.dispatch(f"raw_poll_vote_{event_type}", raw) + + if raw.cached_member is not None and answer is not None: + self.dispatch(f"poll_vote_{event_type}", raw.cached_member, answer) + + def parse_message_poll_vote_add(self, data: gateway.PollVoteAddEvent) -> None: + raw = RawPollVoteActionEvent(data, "POLL_VOTE_ADD") + self._handle_poll_event(raw, "add") + + def parse_message_poll_vote_remove(self, data: gateway.PollVoteRemoveEvent) -> None: + raw = RawPollVoteActionEvent(data, "POLL_VOTE_REMOVE") + self._handle_poll_event(raw, "remove") + def parse_interaction_create(self, data: gateway.InteractionCreateEvent) -> None: # note: this does not use an intermediate variable for `data["type"]` since # it wouldn't allow automatically narrowing the `data` union type based diff --git a/disnake/threads.py b/disnake/threads.py index 8095bbc9a3..b7a11f25c9 100644 --- a/disnake/threads.py +++ b/disnake/threads.py @@ -474,6 +474,7 @@ def permissions_for( if not base.send_messages_in_threads: base.send_tts_messages = False base.send_voice_messages = False + base.send_polls = False base.mention_everyone = False base.embed_links = False base.attach_files = False diff --git a/disnake/types/gateway.py b/disnake/types/gateway.py index 9e81523d29..2926786e5d 100644 --- a/disnake/types/gateway.py +++ b/disnake/types/gateway.py @@ -323,6 +323,24 @@ class MessageReactionRemoveEmojiEvent(TypedDict): emoji: PartialEmoji +# https://discord.com/developers/docs/topics/gateway-events#message-poll-vote-add +class PollVoteAddEvent(TypedDict): + channel_id: Snowflake + guild_id: NotRequired[Snowflake] + message_id: Snowflake + user_id: Snowflake + answer_id: int + + +# https://discord.com/developers/docs/topics/gateway-events#message-poll-vote-remove +class PollVoteRemoveEvent(TypedDict): + channel_id: Snowflake + guild_id: NotRequired[Snowflake] + message_id: Snowflake + user_id: Snowflake + answer_id: int + + # https://discord.com/developers/docs/topics/gateway-events#interaction-create InteractionCreateEvent = BaseInteraction diff --git a/disnake/types/message.py b/disnake/types/message.py index 0b9bbed3c3..424b7ffd66 100644 --- a/disnake/types/message.py +++ b/disnake/types/message.py @@ -12,6 +12,7 @@ from .emoji import PartialEmoji from .interactions import InteractionMessageReference from .member import Member, UserWithMember +from .poll import Poll from .snowflake import Snowflake, SnowflakeList from .sticker import StickerItem from .threads import Thread @@ -115,6 +116,7 @@ class Message(TypedDict): sticker_items: NotRequired[List[StickerItem]] position: NotRequired[int] role_subscription_data: NotRequired[RoleSubscriptionData] + poll: NotRequired[Poll] # specific to MESSAGE_CREATE/MESSAGE_UPDATE events guild_id: NotRequired[Snowflake] diff --git a/disnake/types/poll.py b/disnake/types/poll.py new file mode 100644 index 0000000000..37d33e2206 --- /dev/null +++ b/disnake/types/poll.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from typing import List, Literal, Optional, TypedDict + +from typing_extensions import NotRequired + +from .emoji import PartialEmoji +from .snowflake import Snowflake +from .user import User + + +class PollMedia(TypedDict): + text: NotRequired[str] + emoji: NotRequired[PartialEmoji] + + +class PollAnswer(TypedDict): + # sent only as part of responses from Discord's API/Gateway + answer_id: Snowflake + poll_media: PollMedia + + +PollLayoutType = Literal[1] + + +class PollAnswerCount(TypedDict): + id: Snowflake + count: int + me_voted: bool + + +class PollResult(TypedDict): + is_finalized: bool + answer_counts: List[PollAnswerCount] + + +class PollVoters(TypedDict): + users: List[User] + + +class Poll(TypedDict): + question: PollMedia + answers: List[PollAnswer] + expiry: Optional[str] + allow_multiselect: bool + layout_type: PollLayoutType + # sent only as part of responses from Discord's API/Gateway + results: NotRequired[PollResult] + + +class EmojiPayload(TypedDict): + id: NotRequired[int] + name: NotRequired[str] + + +class PollCreateMediaPayload(TypedDict): + text: NotRequired[str] + emoji: NotRequired[EmojiPayload] + + +class PollCreateAnswerPayload(TypedDict): + poll_media: PollCreateMediaPayload + + +class PollCreatePayload(TypedDict): + question: PollCreateMediaPayload + answers: List[PollCreateAnswerPayload] + duration: int + allow_multiselect: bool + layout_type: NotRequired[int] diff --git a/disnake/webhook/async_.py b/disnake/webhook/async_.py index 7e1228a529..98650f4bf1 100644 --- a/disnake/webhook/async_.py +++ b/disnake/webhook/async_.py @@ -63,6 +63,7 @@ from ..http import Response from ..mentions import AllowedMentions from ..message import Attachment + from ..poll import Poll from ..state import ConnectionState from ..sticker import GuildSticker, StandardSticker, StickerItem from ..types.message import Message as MessagePayload @@ -511,6 +512,7 @@ def handle_message_parameters_dict( allowed_mentions: Optional[AllowedMentions] = MISSING, previous_allowed_mentions: Optional[AllowedMentions] = None, stickers: Sequence[Union[GuildSticker, StandardSticker, StickerItem]] = MISSING, + poll: Poll = MISSING, # these parameters are exclusive to webhooks in forum/media channels thread_name: str = MISSING, applied_tags: Sequence[Snowflake] = MISSING, @@ -579,6 +581,8 @@ def handle_message_parameters_dict( payload["thread_name"] = thread_name if applied_tags: payload["applied_tags"] = [t.id for t in applied_tags] + if poll is not MISSING: + payload["poll"] = poll._to_dict() return DictPayloadParameters(payload=payload, files=files) @@ -602,6 +606,7 @@ def handle_message_parameters( allowed_mentions: Optional[AllowedMentions] = MISSING, previous_allowed_mentions: Optional[AllowedMentions] = None, stickers: Sequence[Union[GuildSticker, StandardSticker, StickerItem]] = MISSING, + poll: Poll = MISSING, # these parameters are exclusive to webhooks in forum/media channels thread_name: str = MISSING, applied_tags: Sequence[Snowflake] = MISSING, @@ -626,6 +631,7 @@ def handle_message_parameters( stickers=stickers, thread_name=thread_name, applied_tags=applied_tags, + poll=poll, ) if params.files: @@ -1495,6 +1501,7 @@ async def send( allowed_mentions: AllowedMentions = ..., view: View = ..., components: Components[MessageUIComponent] = ..., + poll: Poll = ..., thread: Snowflake = ..., thread_name: str = ..., applied_tags: Sequence[Snowflake] = ..., @@ -1521,6 +1528,7 @@ async def send( allowed_mentions: AllowedMentions = ..., view: View = ..., components: Components[MessageUIComponent] = ..., + poll: Poll = ..., thread: Snowflake = ..., thread_name: str = ..., applied_tags: Sequence[Snowflake] = ..., @@ -1551,6 +1559,7 @@ async def send( applied_tags: Sequence[Snowflake] = MISSING, wait: bool = False, delete_after: float = MISSING, + poll: Poll = MISSING, ) -> Optional[WebhookMessage]: """|coro| @@ -1677,6 +1686,11 @@ async def send( .. versionadded:: 2.9 + poll: :class:`Poll` + The poll to send with the message. + + .. versionadded:: 2.10 + Raises ------ HTTPException @@ -1749,6 +1763,7 @@ async def send( applied_tags=applied_tags, allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, + poll=poll, ) adapter = async_context.get() diff --git a/docs/api/events.rst b/docs/api/events.rst index aeb441f9a3..fb0059f1fa 100644 --- a/docs/api/events.rst +++ b/docs/api/events.rst @@ -1243,6 +1243,38 @@ This section documents events related to Discord chat messages. :param messages: The messages that have been deleted. :type messages: List[:class:`Message`] +.. function:: on_poll_vote_add(member, answer) + + Called when a vote is added on a poll. If the member or message is not found in the internal cache, then this event will not be called. + + This requires :attr:`Intents.guild_polls` or :attr:`Intents.dm_polls` to be enabled to receive events about polls sent in guilds or DMs. + + .. note:: + + You can use :attr:`Intents.polls` to enable both :attr:`Intents.guild_polls` and :attr:`Intents.dm_polls` in one go. + + + :param member: The member who voted. + :type member: :class:`Member` + :param answer: The :class:`PollAnswer` object for which the vote was added. + :type answer: :class:`PollAnswer` + +.. function:: on_poll_vote_remove(member, answer) + + Called when a vote is removed on a poll. If the member or message is not found in the internal cache, then this event will not be called. + + This requires :attr:`Intents.guild_polls` or :attr:`Intents.dm_polls` to be enabled to receive events about polls sent in guilds or DMs. + + .. note:: + + You can use :attr:`Intents.polls` to enable both :attr:`Intents.guild_polls` and :attr:`Intents.dm_polls` in one go. + + + :param member: The member who removed the vote. + :type member: :class:`Member` + :param answer: The :class:`PollAnswer` object for which the vote was removed. + :type answer: :class:`PollAnswer` + .. function:: on_raw_message_edit(payload) Called when a message is edited. Unlike :func:`on_message_edit`, this is called @@ -1293,6 +1325,34 @@ This section documents events related to Discord chat messages. :param payload: The raw event payload data. :type payload: :class:`RawBulkMessageDeleteEvent` +.. function:: on_raw_poll_vote_add(payload) + + Called when a vote is added on a poll. Unlike :func:`on_poll_vote_add`, this is + called regardless of the guilds being in the internal guild cache or not. + + This requires :attr:`Intents.guild_polls` or :attr:`Intents.dm_polls` to be enabled to receive events about polls sent in guilds or DMs. + + .. note:: + + You can use :attr:`Intents.polls` to enable both :attr:`Intents.guild_polls` and :attr:`Intents.dm_polls` in one go. + + :param payload: The raw event payload data. + :type payload: :class:`RawPollVoteActionEvent` + +.. function:: on_raw_poll_vote_remove(payload) + + Called when a vote is removed on a poll. Unlike :func:`on_poll_vote_remove`, this is + called regardless of the guilds being in the internal guild cache or not. + + This requires :attr:`Intents.guild_polls` or :attr:`Intents.dm_polls` to be enabled to receive events about polls sent in guilds or DMs. + + .. note:: + + You can use :attr:`Intents.polls` to enable both :attr:`Intents.guild_polls` and :attr:`Intents.dm_polls` in one go. + + :param payload: The raw event payload data. + :type payload: :class:`RawPollVoteActionEvent` + .. function:: on_reaction_add(reaction, user) Called when a message has a reaction added to it. Similar to :func:`on_message_edit`, diff --git a/docs/api/messages.rst b/docs/api/messages.rst index e0a7c8a513..6c02bc0208 100644 --- a/docs/api/messages.rst +++ b/docs/api/messages.rst @@ -94,6 +94,14 @@ RawMessageUpdateEvent .. autoclass:: RawMessageUpdateEvent() :members: +RawPollVoteActionEvent +~~~~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: RawPollVoteActionEvent + +.. autoclass:: RawPollVoteActionEvent() + :members: + RawReactionActionEvent ~~~~~~~~~~~~~~~~~~~~~~ @@ -177,6 +185,30 @@ PartialMessage .. autoclass:: PartialMessage :members: +Poll +~~~~ + +.. attributetable:: Poll + +.. autoclass:: Poll + :members: + +PollAnswer +~~~~~~~~~~ + +.. attributetable:: PollAnswer + +.. autoclass:: PollAnswer + :members: + +PollMedia +~~~~~~~~~ + +.. attributetable:: PollMedia + +.. autoclass:: PollMedia + :members: + Enumerations ------------ @@ -369,6 +401,19 @@ MessageType .. versionadded:: 2.10 +PollLayoutType +~~~~~~~~~~~~~~ + +.. class:: PollLayoutType + + Specifies the layout of a :class:`Poll`. + + .. versionadded:: 2.10 + + .. attribute:: default + + The default poll layout type. + Events ------ @@ -376,10 +421,14 @@ Events - :func:`on_message_edit(before, after) ` - :func:`on_message_delete(message) ` - :func:`on_bulk_message_delete(messages) ` +- :func:`on_poll_vote_add(member, answer) ` +- :func:`on_poll_vote_removed(member, answer) ` - :func:`on_raw_message_edit(payload) ` - :func:`on_raw_message_delete(payload) ` - :func:`on_raw_bulk_message_delete(payload) ` +- :func:`on_raw_poll_vote_add(payload) ` +- :func:`on_raw_poll_vote_remove(payload) ` - :func:`on_reaction_add(reaction, user) ` - :func:`on_reaction_remove(reaction, user) ` diff --git a/docs/intents.rst b/docs/intents.rst index 8c4a41c3ae..79c3872aaa 100644 --- a/docs/intents.rst +++ b/docs/intents.rst @@ -98,7 +98,7 @@ Message Content Intent ++++++++++++++++++++++ - Whether you want a prefix that isn't the bot mention. -- Whether you want to access the contents of messages. This includes content (text), embeds, attachments, and components. +- Whether you want to access the contents of messages. This includes content (text), embeds, attachments, components and polls. .. _need_presence_intent: @@ -191,12 +191,13 @@ As of August 31st, 2022, Discord has blocked message content from being sent to If you are on version 2.4 or before, your bot will be able to access message content without the intent enabled in the code. However, as of version 2.5, it is required to enable :attr:`Intents.message_content` to receive message content over the gateway. -Message content refers to four attributes on the :class:`.Message` object: +Message content refers to five attributes on the :class:`.Message` object: - :attr:`~.Message.content` - :attr:`~.Message.embeds` - :attr:`~.Message.attachments` - :attr:`~.Message.components` +- :attr:`~.Message.poll` You will always receive message content in the following cases even without the message content intent: From 975657a743d0eb056025f21758c24deeb525738e Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Sat, 24 Aug 2024 13:55:50 +0200 Subject: [PATCH 09/14] feat(interactions): deserialize `channel` from data (#1012) --- changelog/1012.feature.rst | 1 + disnake/interactions/application_command.py | 26 ++-- disnake/interactions/base.py | 125 +++++++++++--------- disnake/interactions/message.py | 26 ++-- disnake/interactions/modal.py | 17 ++- disnake/state.py | 54 +++++++++ disnake/types/interactions.py | 9 +- tests/interactions/test_base.py | 42 ++++--- 8 files changed, 203 insertions(+), 97 deletions(-) create mode 100644 changelog/1012.feature.rst diff --git a/changelog/1012.feature.rst b/changelog/1012.feature.rst new file mode 100644 index 0000000000..a59debb4e2 --- /dev/null +++ b/changelog/1012.feature.rst @@ -0,0 +1 @@ +:class:`Interaction`\s now always have a proper :attr:`~Interaction.channel` attribute, even when the bot is not part of the guild or cannot access the channel due to other reasons. diff --git a/disnake/interactions/application_command.py b/disnake/interactions/application_command.py index d25f6a2530..46eee43985 100644 --- a/disnake/interactions/application_command.py +++ b/disnake/interactions/application_command.py @@ -58,8 +58,21 @@ class ApplicationCommandInteraction(Interaction[ClientT]): The application ID that the interaction was for. guild_id: Optional[:class:`int`] The guild ID the interaction was sent from. - channel_id: :class:`int` - The channel ID the interaction was sent from. + channel: Union[:class:`abc.GuildChannel`, :class:`Thread`, :class:`PartialMessageable`] + The channel the interaction was sent from. + + Note that due to a Discord limitation, DM channels + are not resolved as there is no data to complete them. + These are :class:`PartialMessageable` instead. + + .. versionchanged:: 2.10 + If the interaction was sent from a thread and the bot cannot normally access the thread, + this is now a proper :class:`Thread` object. + + .. note:: + If you want to compute the interaction author's or bot's permissions in the channel, + consider using :attr:`permissions` or :attr:`app_permissions`. + author: Union[:class:`User`, :class:`Member`] The user or member that sent the interaction. locale: :class:`Locale` @@ -103,7 +116,7 @@ def __init__( ) -> None: super().__init__(data=data, state=state) self.data: ApplicationCommandInteractionData = ApplicationCommandInteractionData( - data=data["data"], state=state, guild_id=self.guild_id + data=data["data"], parent=self ) self.application_command: InvokableApplicationCommand = MISSING self.command_failed: bool = False @@ -200,17 +213,14 @@ def __init__( self, *, data: ApplicationCommandInteractionDataPayload, - state: ConnectionState, - guild_id: Optional[int], + parent: ApplicationCommandInteraction[ClientT], ) -> None: super().__init__(data) self.id: int = int(data["id"]) self.name: str = data["name"] self.type: ApplicationCommandType = try_enum(ApplicationCommandType, data["type"]) - self.resolved = InteractionDataResolved( - data=data.get("resolved", {}), state=state, guild_id=guild_id - ) + self.resolved = InteractionDataResolved(data=data.get("resolved", {}), parent=parent) self.target_id: Optional[int] = utils._get_as_snowflake(data, "target_id") target = self.resolved.get_by_id(self.target_id) self.target: Optional[Union[User, Member, Message]] = target # type: ignore diff --git a/disnake/interactions/base.py b/disnake/interactions/base.py index 9543cabc66..7e43874bcd 100644 --- a/disnake/interactions/base.py +++ b/disnake/interactions/base.py @@ -21,10 +21,9 @@ from .. import utils from ..app_commands import OptionChoice -from ..channel import PartialMessageable, _threaded_guild_channel_factory +from ..channel import PartialMessageable from ..entitlement import Entitlement from ..enums import ( - ChannelType, ComponentType, InteractionResponseType, InteractionType, @@ -76,7 +75,6 @@ from ..mentions import AllowedMentions from ..poll import Poll from ..state import ConnectionState - from ..threads import Thread from ..types.components import Modal as ModalPayload from ..types.interactions import ( ApplicationCommandOptionChoice as ApplicationCommandOptionChoicePayload, @@ -90,7 +88,8 @@ from .message import MessageInteraction from .modal import ModalInteraction - InteractionChannel = Union[GuildChannel, Thread, PartialMessageable] + InteractionMessageable = Union[GuildMessageable, PartialMessageable] + InteractionChannel = Union[InteractionMessageable, GuildChannel] AnyBot = Union[Bot, AutoShardedBot] @@ -131,8 +130,21 @@ class Interaction(Generic[ClientT]): .. versionchanged:: 2.5 Changed to :class:`Locale` instead of :class:`str`. - channel_id: :class:`int` - The channel ID the interaction was sent from. + channel: Union[:class:`abc.GuildChannel`, :class:`Thread`, :class:`PartialMessageable`] + The channel the interaction was sent from. + + Note that due to a Discord limitation, DM channels + are not resolved as there is no data to complete them. + These are :class:`PartialMessageable` instead. + + .. versionchanged:: 2.10 + If the interaction was sent from a thread and the bot cannot normally access the thread, + this is now a proper :class:`Thread` object. + + .. note:: + If you want to compute the interaction author's or bot's permissions in the channel, + consider using :attr:`permissions` or :attr:`app_permissions`. + author: Union[:class:`User`, :class:`Member`] The user or member that sent the interaction. locale: :class:`Locale` @@ -159,7 +171,7 @@ class Interaction(Generic[ClientT]): "id", "type", "guild_id", - "channel_id", + "channel", "application_id", "author", "token", @@ -175,7 +187,6 @@ class Interaction(Generic[ClientT]): "_original_response", "_cs_response", "_cs_followup", - "_cs_channel", "_cs_me", "_cs_expires_at", ) @@ -193,8 +204,6 @@ def __init__(self, *, data: InteractionPayload, state: ConnectionState) -> None: self.token: str = data["token"] self.version: int = data["version"] self.application_id: int = int(data["application_id"]) - - self.channel_id: int = int(data["channel_id"]) self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id") self.locale: Locale = try_enum(Locale, data["locale"]) @@ -208,17 +217,29 @@ def __init__(self, *, data: InteractionPayload, state: ConnectionState) -> None: # one of user and member will always exist self.author: Union[User, Member] = MISSING - if self.guild_id and (member := data.get("member")): - guild: Guild = self.guild or Object(id=self.guild_id) # type: ignore + guild_fallback: Optional[Union[Guild, Object]] = None + if self.guild_id: + guild_fallback = self.guild or Object(self.guild_id) + + if guild_fallback and (member := data.get("member")): self.author = ( - isinstance(guild, Guild) - and guild.get_member(int(member["user"]["id"])) - or Member(state=self._state, guild=guild, data=member) + isinstance(guild_fallback, Guild) + and guild_fallback.get_member(int(member["user"]["id"])) + or Member( + state=self._state, + guild=guild_fallback, # type: ignore # may be `Object` + data=member, + ) ) self._permissions = int(member.get("permissions", 0)) elif user := data.get("user"): self.author = self._state.store_user(user) + # TODO: consider making this optional in 3.0 + self.channel: InteractionMessageable = state._get_partial_interaction_channel( + data["channel"], guild_fallback, return_messageable=True + ) + self.entitlements: List[Entitlement] = ( [Entitlement(data=e, state=state) for e in entitlements_data] if (entitlements_data := data.get("entitlements")) @@ -256,24 +277,13 @@ def me(self) -> Union[Member, ClientUser]: return None if self.bot is None else self.bot.user # type: ignore return self.guild.me - @utils.cached_slot_property("_cs_channel") - def channel(self) -> Union[GuildMessageable, PartialMessageable]: - """Union[:class:`abc.GuildChannel`, :class:`Thread`, :class:`PartialMessageable`]: The channel the interaction was sent from. - - Note that due to a Discord limitation, threads that the bot cannot access and DM channels - are not resolved since there is no data to complete them. - These are :class:`PartialMessageable` instead. + @property + def channel_id(self) -> int: + """The channel ID the interaction was sent from. - If you want to compute the interaction author's or bot's permissions in the channel, - consider using :attr:`permissions` or :attr:`app_permissions` instead. + See also :attr:`channel`. """ - guild = self.guild - channel = guild and guild._resolve_channel(self.channel_id) - if channel is None: - # could be a thread channel in a guild, or a DM channel - type = None if self.guild_id is not None else ChannelType.private - return PartialMessageable(state=self._state, id=self.channel_id, type=type) - return channel # type: ignore + return self.channel.id @property def permissions(self) -> Permissions: @@ -1873,8 +1883,7 @@ def __init__( self, *, data: InteractionDataResolvedPayload, - state: ConnectionState, - guild_id: Optional[int], + parent: Interaction[ClientT], ) -> None: data = data or {} super().__init__(data) @@ -1893,6 +1902,9 @@ def __init__( messages = data.get("messages", {}) attachments = data.get("attachments", {}) + state = parent._state + guild_id = parent.guild_id + guild: Optional[Guild] = None # `guild_fallback` is only used in guild contexts, so this `MISSING` value should never be used. # We need to define it anyway to satisfy the typechecker. @@ -1925,36 +1937,35 @@ def __init__( data=role, ) - for str_id, channel in channels.items(): - channel_id = int(str_id) - factory, _ = _threaded_guild_channel_factory(channel["type"]) - if factory: - channel["position"] = 0 # type: ignore - self.channels[channel_id] = ( - guild - and guild.get_channel_or_thread(channel_id) - or factory( - guild=guild_fallback, - state=state, - data=channel, # type: ignore - ) - ) - else: - # TODO: guild_directory is not messageable - self.channels[channel_id] = PartialMessageable( - state=state, id=channel_id, type=try_enum(ChannelType, channel["type"]) - ) + for str_id, channel_data in channels.items(): + self.channels[int(str_id)] = state._get_partial_interaction_channel( + channel_data, guild_fallback + ) for str_id, message in messages.items(): channel_id = int(message["channel_id"]) - channel = cast( - "Optional[MessageableChannel]", - (guild and guild.get_channel(channel_id) or state.get_channel(channel_id)), - ) + channel: Optional[MessageableChannel] = None + + if ( + channel_id == parent.channel.id + # we still want to fall back to state.get_channel when the + # parent channel is a dm/group channel, for now. + # FIXME: remove this once `parent.channel` supports `DMChannel` + and not isinstance(parent.channel, PartialMessageable) + ): + # fast path, this should generally be the case + channel = parent.channel + else: + channel = cast( + "Optional[MessageableChannel]", + (guild and guild.get_channel(channel_id) or state.get_channel(channel_id)), + ) + if channel is None: - # The channel is not part of `resolved.channels`, + # n.b. the message's channel is not sent as part of `resolved.channels`, # so we need to fall back to partials here. channel = PartialMessageable(state=state, id=channel_id, type=None) + self.messages[int(str_id)] = Message(state=state, channel=channel, data=message) for str_id, attachment in attachments.items(): diff --git a/disnake/interactions/message.py b/disnake/interactions/message.py index 4ef51165d5..8ce8c3d3ab 100644 --- a/disnake/interactions/message.py +++ b/disnake/interactions/message.py @@ -47,8 +47,21 @@ class MessageInteraction(Interaction[ClientT]): The token to continue the interaction. These are valid for 15 minutes. guild_id: Optional[:class:`int`] The guild ID the interaction was sent from. - channel_id: :class:`int` - The channel ID the interaction was sent from. + channel: Union[:class:`abc.GuildChannel`, :class:`Thread`, :class:`PartialMessageable`] + The channel the interaction was sent from. + + Note that due to a Discord limitation, DM channels + are not resolved as there is no data to complete them. + These are :class:`PartialMessageable` instead. + + .. versionchanged:: 2.10 + If the interaction was sent from a thread and the bot cannot normally access the thread, + this is now a proper :class:`Thread` object. + + .. note:: + If you want to compute the interaction author's or bot's permissions in the channel, + consider using :attr:`permissions` or :attr:`app_permissions`. + author: Union[:class:`User`, :class:`Member`] The user or member that sent the interaction. locale: :class:`Locale` @@ -85,9 +98,7 @@ class MessageInteraction(Interaction[ClientT]): def __init__(self, *, data: MessageInteractionPayload, state: ConnectionState) -> None: super().__init__(data=data, state=state) - self.data: MessageInteractionData = MessageInteractionData( - data=data["data"], state=state, guild_id=self.guild_id - ) + self.data: MessageInteractionData = MessageInteractionData(data=data["data"], parent=self) self.message = Message(state=self._state, channel=self.channel, data=data["message"]) @property @@ -167,8 +178,7 @@ def __init__( self, *, data: MessageComponentInteractionDataPayload, - state: ConnectionState, - guild_id: Optional[int], + parent: MessageInteraction[ClientT], ) -> None: super().__init__(data) self.custom_id: str = data["custom_id"] @@ -179,7 +189,7 @@ def __init__( empty_resolved: InteractionDataResolvedPayload = {} # pyright shenanigans self.resolved = InteractionDataResolved( - data=data.get("resolved", empty_resolved), state=state, guild_id=guild_id + data=data.get("resolved", empty_resolved), parent=parent ) def __repr__(self) -> str: diff --git a/disnake/interactions/modal.py b/disnake/interactions/modal.py index f631c38ac2..be9520b1cf 100644 --- a/disnake/interactions/modal.py +++ b/disnake/interactions/modal.py @@ -39,8 +39,21 @@ class ModalInteraction(Interaction[ClientT]): These are valid for 15 minutes. guild_id: Optional[:class:`int`] The guild ID the interaction was sent from. - channel_id: :class:`int` - The channel ID the interaction was sent from. + channel: Union[:class:`abc.GuildChannel`, :class:`Thread`, :class:`PartialMessageable`] + The channel the interaction was sent from. + + Note that due to a Discord limitation, DM channels + are not resolved as there is no data to complete them. + These are :class:`PartialMessageable` instead. + + .. versionchanged:: 2.10 + If the interaction was sent from a thread and the bot cannot normally access the thread, + this is now a proper :class:`Thread` object. + + .. note:: + If you want to compute the interaction author's or bot's permissions in the channel, + consider using :attr:`permissions` or :attr:`app_permissions`. + author: Union[:class:`User`, :class:`Member`] The user or member that sent the interaction. locale: :class:`Locale` diff --git a/disnake/state.py b/disnake/state.py index ab4e5a8d78..84798d2fa5 100644 --- a/disnake/state.py +++ b/disnake/state.py @@ -43,6 +43,7 @@ TextChannel, VoiceChannel, _guild_channel_factory, + _threaded_guild_channel_factory, ) from .emoji import Emoji from .entitlement import Entitlement @@ -96,11 +97,13 @@ from .gateway import DiscordWebSocket from .guild import GuildChannel, VocalGuildChannel from .http import HTTPClient + from .interactions.base import InteractionChannel, InteractionMessageable from .types import gateway from .types.activity import Activity as ActivityPayload from .types.channel import DMChannel as DMChannelPayload from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload from .types.guild import Guild as GuildPayload, UnavailableGuild as UnavailableGuildPayload + from .types.interactions import InteractionChannel as InteractionChannelPayload from .types.message import Message as MessagePayload from .types.sticker import GuildSticker as GuildStickerPayload from .types.user import User as UserPayload @@ -2029,6 +2032,57 @@ def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmo except KeyError: return emoji + @overload + def _get_partial_interaction_channel( + self, + data: InteractionChannelPayload, + guild: Optional[Union[Guild, Object]], + *, + return_messageable: Literal[False] = False, + ) -> InteractionChannel: + ... + + @overload + def _get_partial_interaction_channel( + self, + data: InteractionChannelPayload, + guild: Optional[Union[Guild, Object]], + *, + return_messageable: Literal[True], + ) -> InteractionMessageable: + ... + + # note: this resolves private channels (and unknown types) to `PartialMessageable` + def _get_partial_interaction_channel( + self, + data: InteractionChannelPayload, + guild: Optional[Union[Guild, Object]], + *, + # this param is purely for type-checking, it has no effect on runtime behavior. + return_messageable: bool = False, + ) -> InteractionChannel: + channel_id = int(data["id"]) + channel_type = data["type"] + + factory, _ = _threaded_guild_channel_factory(channel_type) + if not factory or not guild: + return PartialMessageable( + state=self, + id=channel_id, + type=try_enum(ChannelType, channel_type), + ) + + data.setdefault("position", 0) # type: ignore + return ( + isinstance(guild, Guild) + and guild.get_channel_or_thread(channel_id) + or factory( + guild=guild, # type: ignore # FIXME: create proper fallback guild instead of passing Object + state=self, + data=data, # type: ignore # generic payload type + ) + ) + def get_channel(self, id: Optional[int]) -> Optional[Union[Channel, Thread]]: if id is None: return None diff --git a/disnake/types/interactions.py b/disnake/types/interactions.py index efffa8e599..88498da81f 100644 --- a/disnake/types/interactions.py +++ b/disnake/types/interactions.py @@ -89,7 +89,7 @@ class GuildApplicationCommandPermissions(TypedDict): InteractionType = Literal[1, 2, 3, 4, 5] -class ResolvedPartialChannel(TypedDict): +class InteractionChannel(TypedDict): id: Snowflake type: ChannelType permissions: str @@ -104,7 +104,7 @@ class InteractionDataResolved(TypedDict, total=False): users: Dict[Snowflake, User] members: Dict[Snowflake, Member] roles: Dict[Snowflake, Role] - channels: Dict[Snowflake, ResolvedPartialChannel] + channels: Dict[Snowflake, InteractionChannel] # only in application commands messages: Dict[Snowflake, Message] attachments: Dict[Snowflake, Attachment] @@ -258,9 +258,10 @@ class _BaseInteraction(TypedDict): # common properties in non-ping interactions class _BaseUserInteraction(_BaseInteraction): - # the docs specify `channel_id` as optional, - # but it is assumed to always exist on non-ping interactions + # the docs specify `channel_id` and 'channel` as optional, + # but they're assumed to always exist on non-ping interactions channel_id: Snowflake + channel: InteractionChannel locale: str app_permissions: NotRequired[str] guild_id: NotRequired[Snowflake] diff --git a/tests/interactions/test_base.py b/tests/interactions/test_base.py index 24d937b685..5e364072dc 100644 --- a/tests/interactions/test_base.py +++ b/tests/interactions/test_base.py @@ -8,12 +8,12 @@ import pytest import disnake -from disnake import InteractionResponseType as ResponseType # shortcut +from disnake import Interaction, InteractionResponseType as ResponseType # shortcut from disnake.state import ConnectionState from disnake.utils import MISSING if TYPE_CHECKING: - from disnake.types.interactions import ResolvedPartialChannel as ResolvedPartialChannelPayload + from disnake.types.interactions import InteractionChannel as InteractionChannelPayload from disnake.types.member import Member as MemberPayload from disnake.types.user import User as UserPayload @@ -137,7 +137,14 @@ def state(self): s._get_guild.return_value = None return s - def test_init_member(self, state) -> None: + @pytest.fixture + def interaction(self, state): + i = mock.Mock(spec_set=Interaction) + i._state = state + i.guild_id = 1234 + return i + + def test_init_member(self, interaction) -> None: member_payload: MemberPayload = { "roles": [], "joined_at": "2022-09-02T22:00:55.069000+00:00", @@ -156,8 +163,7 @@ def test_init_member(self, state) -> None: # user only, should deserialize user object resolved = disnake.InteractionDataResolved( data={"users": {"1234": user_payload}}, - state=state, - guild_id=1234, + parent=interaction, ) assert len(resolved.members) == 0 assert len(resolved.users) == 1 @@ -165,8 +171,7 @@ def test_init_member(self, state) -> None: # member only, shouldn't deserialize anything resolved = disnake.InteractionDataResolved( data={"members": {"1234": member_payload}}, - state=state, - guild_id=1234, + parent=interaction, ) assert len(resolved.members) == 0 assert len(resolved.users) == 0 @@ -174,15 +179,14 @@ def test_init_member(self, state) -> None: # user + member, should deserialize member object only resolved = disnake.InteractionDataResolved( data={"users": {"1234": user_payload}, "members": {"1234": member_payload}}, - state=state, - guild_id=1234, + parent=interaction, ) assert len(resolved.members) == 1 assert len(resolved.users) == 0 - @pytest.mark.parametrize("channel_type", [t.value for t in disnake.ChannelType]) + @pytest.mark.parametrize("channel_type", [t.value for t in disnake.ChannelType] + [99]) def test_channel(self, state, channel_type) -> None: - channel_data: ResolvedPartialChannelPayload = { + channel_data: InteractionChannelPayload = { "id": "42", "type": channel_type, "permissions": "7", @@ -197,12 +201,14 @@ def test_channel(self, state, channel_type) -> None: "locked": False, } - resolved = disnake.InteractionDataResolved( - data={"channels": {"42": channel_data}}, state=state, guild_id=1234 + # this should not raise + channel = ConnectionState._get_partial_interaction_channel( + state, + channel_data, + disnake.Object(1234), + return_messageable=False, ) - assert len(resolved.channels) == 1 - channel = next(iter(resolved.channels.values())) - # should be partial if and only if it's a dm/group - # TODO: currently includes directory channels (14), see `InteractionDataResolved.__init__` - assert isinstance(channel, disnake.PartialMessageable) == (channel_type in (1, 3, 14)) + # should be partial if and only if it's a dm/group or unknown + # TODO: currently includes directory channels (14), see `_get_partial_interaction_channel` + assert isinstance(channel, disnake.PartialMessageable) == (channel_type in (1, 3, 14, 99)) From 0a5ab1e3c228c1a5ca83db836c709a89122a0f44 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Sat, 24 Aug 2024 14:18:35 +0200 Subject: [PATCH 10/14] feat(message): include member data in `InteractionReference.user` (#1160) --- changelog/1160.feature.rst | 1 + disnake/message.py | 55 ++++++++++++++++++++++++++--------- disnake/state.py | 6 ++++ disnake/types/interactions.py | 1 + 4 files changed, 50 insertions(+), 13 deletions(-) create mode 100644 changelog/1160.feature.rst diff --git a/changelog/1160.feature.rst b/changelog/1160.feature.rst new file mode 100644 index 0000000000..973453cc21 --- /dev/null +++ b/changelog/1160.feature.rst @@ -0,0 +1 @@ +:attr:`InteractionReference.user` can now be a :class:`Member` in guild contexts. diff --git a/disnake/message.py b/disnake/message.py index 15fe00c5a9..9d6c16d56d 100644 --- a/disnake/message.py +++ b/disnake/message.py @@ -704,24 +704,48 @@ class InteractionReference: For interaction references created before July 18th, 2022, this will not include group or subcommand names. - user: :class:`User` - The interaction author. + user: Union[:class:`User`, :class:`Member`] + The user or member that triggered the referenced interaction. + + .. versionchanged:: 2.10 + This is now a :class:`Member` when in a guild, if the message was received via a + gateway event or the member is cached. """ - __slots__ = ("id", "type", "name", "user", "_state") + __slots__ = ("id", "type", "name", "user") - def __init__(self, *, state: ConnectionState, data: InteractionMessageReferencePayload) -> None: - self._state: ConnectionState = state + def __init__( + self, + *, + state: ConnectionState, + guild: Optional[Guild], + data: InteractionMessageReferencePayload, + ) -> None: self.id: int = int(data["id"]) self.type: InteractionType = try_enum(InteractionType, int(data["type"])) self.name: str = data["name"] - self.user: User = User(state=state, data=data["user"]) + + user: Optional[Union[User, Member]] = None + if guild: + if isinstance(guild, Guild): # this can be a placeholder object in interactions + user = guild.get_member(int(data["user"]["id"])) + + # If not cached, try data from event. + # This is only available via gateway (message_create/_edit), not HTTP + if not user and (member := data.get("member")): + user = Member(data=member, user_data=data["user"], guild=guild, state=state) + + # If still none, deserialize user + if not user: + user = state.store_user(data["user"]) + + self.user: Union[User, Member] = user def __repr__(self) -> str: return f"" @property - def author(self) -> User: + def author(self) -> Union[User, Member]: return self.user @@ -1003,12 +1027,6 @@ def __init__( for d in data.get("components", []) ] - inter_payload = data.get("interaction") - inter = ( - None if inter_payload is None else InteractionReference(state=state, data=inter_payload) - ) - self.interaction: Optional[InteractionReference] = inter - self.poll: Optional[Poll] = None if poll_data := data.get("poll"): self.poll = Poll.from_dict(message=self, data=poll_data) @@ -1019,6 +1037,12 @@ def __init__( except AttributeError: self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id")) + self.interaction: Optional[InteractionReference] = ( + InteractionReference(state=state, guild=self.guild, data=interaction) + if (interaction := data.get("interaction")) + else None + ) + if thread_data := data.get("thread"): if not self.thread and isinstance(self.guild, Guild): self.guild._store_thread(thread_data) @@ -1236,8 +1260,13 @@ def _handle_components(self, components: List[ComponentPayload]) -> None: def _rebind_cached_references(self, new_guild: Guild, new_channel: GuildMessageable) -> None: self.guild = new_guild self.channel = new_channel + + # rebind the members' guilds; the members themselves will potentially be + # updated later in _update_member_references, after re-chunking if isinstance(self.author, Member): self.author.guild = new_guild + if self.interaction and isinstance(self.interaction.user, Member): + self.interaction.user.guild = new_guild @utils.cached_slot_property("_cs_raw_mentions") def raw_mentions(self) -> List[int]: diff --git a/disnake/state.py b/disnake/state.py index 84798d2fa5..222b2ad3ee 100644 --- a/disnake/state.py +++ b/disnake/state.py @@ -2258,6 +2258,7 @@ def _update_guild_channel_references(self) -> None: if new_guild is None: continue + # TODO: use PartialMessageable instead of Object (3.0) new_channel = new_guild._resolve_channel(vc.channel.id) or Object(id=vc.channel.id) if new_channel is not vc.channel: vc.channel = new_channel # type: ignore @@ -2275,6 +2276,11 @@ def _update_member_references(self) -> None: if new_author is not None and new_author is not msg.author: msg.author = new_author + if msg.interaction is not None and isinstance(msg.interaction.user, Member): + new_author = msg.guild.get_member(msg.interaction.user.id) + if new_author is not None and new_author is not msg.interaction.user: + msg.interaction.user = new_author + async def chunker( self, guild_id: int, diff --git a/disnake/types/interactions.py b/disnake/types/interactions.py index 88498da81f..9cb8393ea5 100644 --- a/disnake/types/interactions.py +++ b/disnake/types/interactions.py @@ -335,6 +335,7 @@ class InteractionMessageReference(TypedDict): type: InteractionType name: str user: User + member: NotRequired[Member] class EditApplicationCommand(TypedDict): From 2206241fe87584c285b6cee5d7dac438547407d3 Mon Sep 17 00:00:00 2001 From: Snipy7374 <100313469+Snipy7374@users.noreply.github.com> Date: Sat, 31 Aug 2024 12:30:20 +0200 Subject: [PATCH 11/14] fix(cog): fix typing of cog check methods (#1232) --- disnake/ext/commands/cog.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/disnake/ext/commands/cog.py b/disnake/ext/commands/cog.py index 01fd59937c..bc1cff6e9a 100644 --- a/disnake/ext/commands/cog.py +++ b/disnake/ext/commands/cog.py @@ -34,6 +34,7 @@ from disnake.interactions import ApplicationCommandInteraction + from ._types import MaybeCoro from .bot import AutoShardedBot, AutoShardedInteractionBot, Bot, InteractionBot from .context import Context from .core import Command @@ -491,7 +492,7 @@ def cog_unload(self) -> None: pass @_cog_special_method - def bot_check_once(self, ctx: Context) -> bool: + def bot_check_once(self, ctx: Context) -> MaybeCoro[bool]: """A special method that registers as a :meth:`.Bot.check_once` check. @@ -503,7 +504,7 @@ def bot_check_once(self, ctx: Context) -> bool: return True @_cog_special_method - def bot_check(self, ctx: Context) -> bool: + def bot_check(self, ctx: Context) -> MaybeCoro[bool]: """A special method that registers as a :meth:`.Bot.check` check. @@ -515,7 +516,7 @@ def bot_check(self, ctx: Context) -> bool: return True @_cog_special_method - def bot_slash_command_check_once(self, inter: ApplicationCommandInteraction) -> bool: + def bot_slash_command_check_once(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]: """A special method that registers as a :meth:`.Bot.slash_command_check_once` check. @@ -525,7 +526,7 @@ def bot_slash_command_check_once(self, inter: ApplicationCommandInteraction) -> return True @_cog_special_method - def bot_slash_command_check(self, inter: ApplicationCommandInteraction) -> bool: + def bot_slash_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]: """A special method that registers as a :meth:`.Bot.slash_command_check` check. @@ -535,27 +536,29 @@ def bot_slash_command_check(self, inter: ApplicationCommandInteraction) -> bool: return True @_cog_special_method - def bot_user_command_check_once(self, inter: ApplicationCommandInteraction) -> bool: + def bot_user_command_check_once(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]: """Similar to :meth:`.Bot.slash_command_check_once` but for user commands.""" return True @_cog_special_method - def bot_user_command_check(self, inter: ApplicationCommandInteraction) -> bool: + def bot_user_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]: """Similar to :meth:`.Bot.slash_command_check` but for user commands.""" return True @_cog_special_method - def bot_message_command_check_once(self, inter: ApplicationCommandInteraction) -> bool: + def bot_message_command_check_once( + self, inter: ApplicationCommandInteraction + ) -> MaybeCoro[bool]: """Similar to :meth:`.Bot.slash_command_check_once` but for message commands.""" return True @_cog_special_method - def bot_message_command_check(self, inter: ApplicationCommandInteraction) -> bool: + def bot_message_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]: """Similar to :meth:`.Bot.slash_command_check` but for message commands.""" return True @_cog_special_method - def cog_check(self, ctx: Context) -> bool: + def cog_check(self, ctx: Context) -> MaybeCoro[bool]: """A special method that registers as a :func:`~.check` for every text command and subcommand in this cog. @@ -567,7 +570,7 @@ def cog_check(self, ctx: Context) -> bool: return True @_cog_special_method - def cog_slash_command_check(self, inter: ApplicationCommandInteraction) -> bool: + def cog_slash_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]: """A special method that registers as a :func:`~.check` for every slash command and subcommand in this cog. @@ -577,12 +580,12 @@ def cog_slash_command_check(self, inter: ApplicationCommandInteraction) -> bool: return True @_cog_special_method - def cog_user_command_check(self, inter: ApplicationCommandInteraction) -> bool: + def cog_user_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]: """Similar to :meth:`.Cog.cog_slash_command_check` but for user commands.""" return True @_cog_special_method - def cog_message_command_check(self, inter: ApplicationCommandInteraction) -> bool: + def cog_message_command_check(self, inter: ApplicationCommandInteraction) -> MaybeCoro[bool]: """Similar to :meth:`.Cog.cog_slash_command_check` but for message commands.""" return True From c7123d6e3d4c2caab9149e7870f7f37a10f6c924 Mon Sep 17 00:00:00 2001 From: Snipy7374 <100313469+Snipy7374@users.noreply.github.com> Date: Tue, 3 Sep 2024 18:16:51 +0200 Subject: [PATCH 12/14] fix(changelog): fix poll changelog (#1234) --- changelog/1175.feature.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/changelog/1175.feature.rst b/changelog/1175.feature.rst index 78dd79b311..5c839ba33f 100644 --- a/changelog/1175.feature.rst +++ b/changelog/1175.feature.rst @@ -1,6 +1,6 @@ Add the new poll discord API feature. This includes the following new classes and events: -- New types: :class:`Poll`, :class:`PollAnswer`, :class:`PollMedia`, :class:`RawMessagePollVoteActionEvent` and :class:`PollLayoutType`. +- New types: :class:`Poll`, :class:`PollAnswer`, :class:`PollMedia`, :class:`RawPollVoteActionEvent` and :class:`PollLayoutType`. - Edited :meth:`abc.Messageable.send`, :meth:`Webhook.send`, :meth:`ext.commands.Context.send` and :meth:`disnake.InteractionResponse.send_message` to be able to send polls. - Edited :class:`Message` to store a new :attr:`Message.poll` attribute for polls. -- Edited :class:`Event` to contain the new :func:`on_message_poll_vote_add`, :func:`on_message_poll_vote_remove`, :func:`on_raw_message_poll_vote_add` and :func:`on_raw_message_poll_vote_remove`. +- Edited :class:`Event` to contain the new :func:`on_poll_vote_add`, :func:`on_poll_vote_remove`, :func:`on_raw_poll_vote_add` and :func:`on_raw_poll_vote_remove`. From 4a5475c22047c1837cc97aa5ea8ed7e2e20a9db6 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Thu, 5 Sep 2024 18:57:29 +0200 Subject: [PATCH 13/14] docs: update build for new readthedocs addons (#1207) --- docs/_static/style.css | 4 ++-- docs/_templates/layout.html | 10 ---------- docs/conf.py | 24 +++++++++++------------- 3 files changed, 13 insertions(+), 25 deletions(-) diff --git a/docs/_static/style.css b/docs/_static/style.css index b89e43a930..fb9bece8b5 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -103,7 +103,7 @@ Historically however, thanks to: --rtd-ad-background: var(--grey-2); --rtd-ad-main-text: var(--grey-6); --rtd-ad-small-text: var(--grey-4); - --rtd-version-background: #272525; + --rtd-version-background: #272725; --rtd-version-main-text: #fcfcfc; --attribute-table-title: var(--grey-6); --attribute-table-list-border: var(--grey-3); @@ -826,7 +826,7 @@ section h3 { } #to-top.is-rtd { - bottom: 90px; + bottom: 100px; } #to-top > span { diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html index ef97d73439..305ec42b45 100644 --- a/docs/_templates/layout.html +++ b/docs/_templates/layout.html @@ -156,16 +156,6 @@ {%- endblock %} - {%- if READTHEDOCS %} - - {%- endif %}