Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement __eq__ between BaseFlags and flag_values #1238

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/1238.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for ``BaseFlags`` to allow comparison with ``flag_values`` and vice versa.
16 changes: 15 additions & 1 deletion disnake/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def __init__(self, func: Callable[[Any], int]) -> None:
self.__doc__ = func.__doc__
self._parent: Type[T] = MISSING

def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return self.flag == other.flag
if isinstance(other, BaseFlags):
return self._parent is other.__class__ and self.flag == other.value
return False

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

def __or__(self, other: Union[flag_value[T], T]) -> T:
if isinstance(other, BaseFlags):
if self._parent is not other.__class__:
Expand Down Expand Up @@ -148,7 +158,11 @@ def _from_value(cls, value: int) -> Self:
return self

def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.value == other.value
if isinstance(other, self.__class__):
return self.value == other.value
if isinstance(other, flag_value):
return self.__class__ is other._parent and self.value == other.flag
return False

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,21 @@ def test__eq__(self) -> None:
assert not ins == other
assert ins != other

def test__eq__flag_value(self) -> None:
ins = TestFlags(one=True)
other = TestFlags(one=True, two=True)

assert ins == TestFlags.one
assert TestFlags.one == ins

assert not ins != TestFlags.one
assert ins != TestFlags.two

assert other != TestFlags.one
assert other != TestFlags.two

assert other == TestFlags.three

def test__and__(self) -> None:
ins = TestFlags(one=True, two=True)
other = TestFlags(one=True, two=True)
Expand Down
Loading