Skip to content

Commit

Permalink
Add helper functions for matching True/False/None literals
Browse files Browse the repository at this point in the history
  • Loading branch information
dosisod committed Mar 22, 2024
1 parent 714c564 commit 895109c
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 52 deletions.
10 changes: 4 additions & 6 deletions refurb/checks/builtin/no_sorted_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from mypy.nodes import CallExpr, IndexExpr, IntExpr, NameExpr, UnaryExpr

from refurb.checks.common import stringify
from refurb.checks.common import is_true_literal, stringify
from refurb.error import Error


Expand Down Expand Up @@ -52,12 +52,10 @@ def check(node: IndexExpr, errors: list[Error]) -> None:

for arg_name, arg in zip(arg_names, args):
if arg_name == "reverse":
match arg:
case NameExpr(fullname="builtins.True"):
is_reversed = True
if not is_true_literal(arg):
return

case _:
return
is_reversed = True

elif arg_name == "key":
key = f", key={stringify(arg)}"
Expand Down
33 changes: 25 additions & 8 deletions refurb/checks/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable
from itertools import chain, combinations, starmap
from typing import Any
from typing import Any, TypeGuard

from mypy.nodes import (
ArgKind,
Expand Down Expand Up @@ -285,13 +285,17 @@ def is_type_none_call(node: Expression) -> bool:
match node:
case CallExpr(
callee=NameExpr(fullname="builtins.type"),
args=[NameExpr(fullname="builtins.None")],
):
args=[arg],
) if is_none_literal(arg):
return True

return False


def is_none_literal(node: Node) -> TypeGuard[NameExpr]:
return isinstance(node, NameExpr) and node.fullname == "builtins.None"


def get_fstring_parts(expr: Expression) -> list[tuple[bool, Expression, str]]:
match expr:
case CallExpr(
Expand Down Expand Up @@ -623,8 +627,12 @@ def get_mypy_type(node: Node) -> Type | SymbolNode | None:
case ComplexExpr():
return _get_builtin_mypy_type("complex")

case NameExpr(fullname="builtins.True" | "builtins.False"):
return _get_builtin_mypy_type("bool")
case NameExpr():
if is_bool_literal(node):
return _get_builtin_mypy_type("bool")

if node.node:
return get_mypy_type(node.node)

case DictExpr():
return _get_builtin_mypy_type("dict")
Expand All @@ -644,9 +652,6 @@ def get_mypy_type(node: Node) -> Type | SymbolNode | None:
case TypeInfo() | TypeAlias() | MypyFile():
return node

case NameExpr(node=sym) if sym:
return get_mypy_type(sym)

case MemberExpr(expr=lhs, name=name):
ty = get_mypy_type(lhs)

Expand Down Expand Up @@ -763,3 +768,15 @@ def is_sized_type(ty: Type | SymbolNode | None) -> bool:
"_collections_abc.dict_keys",
"_collections_abc.dict_values",
)


def is_bool_literal(node: Node) -> TypeGuard[NameExpr]:
return is_true_literal(node) or is_false_literal(node)


def is_true_literal(node: Node) -> TypeGuard[NameExpr]:
return isinstance(node, NameExpr) and node.fullname == "builtins.True"


def is_false_literal(node: Node) -> TypeGuard[NameExpr]:
return isinstance(node, NameExpr) and node.fullname == "builtins.False"
7 changes: 4 additions & 3 deletions refurb/checks/functools/use_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dataclasses import dataclass

from mypy.nodes import ArgKind, CallExpr, Decorator, MemberExpr, NameExpr, RefExpr
from mypy.nodes import ArgKind, CallExpr, Decorator, MemberExpr, RefExpr

from refurb.checks.common import is_none_literal
from refurb.error import Error
from refurb.settings import Settings

Expand Down Expand Up @@ -50,10 +51,10 @@ def check(node: Decorator, errors: list[Error], settings: Settings) -> None:
callee=RefExpr(fullname="functools.lru_cache") as ref,
arg_names=["maxsize"],
arg_kinds=[ArgKind.ARG_NAMED],
args=[NameExpr(fullname="builtins.None")],
args=[arg],
)
]
):
) if is_none_literal(arg):
prefix = "functools." if isinstance(ref, MemberExpr) else ""
old = f"@{prefix}lru_cache(maxsize=None)"
new = f"@{prefix}cache"
Expand Down
13 changes: 2 additions & 11 deletions refurb/checks/readability/no_is_bool_compare.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from dataclasses import dataclass
from typing import TypeGuard

from mypy.nodes import ComparisonExpr, Expression, NameExpr
from mypy.nodes import ComparisonExpr, Expression

from refurb.checks.common import get_mypy_type, is_same_type, stringify
from refurb.checks.common import get_mypy_type, is_bool_literal, is_same_type, stringify
from refurb.error import Error


Expand Down Expand Up @@ -37,14 +36,6 @@ class ErrorInfo(Error):
categories = ("logical", "readability", "truthy")


def is_bool_literal(expr: Expression) -> TypeGuard[NameExpr]:
match expr:
case NameExpr(fullname="builtins.True" | "builtins.False"):
return True

return False


def is_bool_variable(expr: Expression) -> bool:
return is_same_type(get_mypy_type(expr), bool)

Expand Down
38 changes: 16 additions & 22 deletions refurb/checks/readability/use_isinstance_bool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from dataclasses import dataclass

from mypy.nodes import ComparisonExpr, Expression, ListExpr, NameExpr, OpExpr, SetExpr, TupleExpr

from refurb.checks.common import extract_binary_oper, is_equivalent, stringify
from mypy.nodes import ComparisonExpr, ListExpr, OpExpr, SetExpr, TupleExpr

from refurb.checks.common import (
extract_binary_oper,
is_equivalent,
is_false_literal,
is_true_literal,
stringify,
)
from refurb.error import Error


Expand Down Expand Up @@ -32,23 +38,6 @@ class ErrorInfo(Error):
categories = ("readability",)


# TODO: move to common
def is_true(expr: Expression) -> bool:
match expr:
case NameExpr(fullname="builtins.True"):
return True

return False


def is_false(expr: Expression) -> bool:
match expr:
case NameExpr(fullname="builtins.False"):
return True

return False


def check(node: ComparisonExpr | OpExpr, errors: list[Error]) -> None:
match node:
case ComparisonExpr(
Expand All @@ -57,7 +46,9 @@ def check(node: ComparisonExpr | OpExpr, errors: list[Error]) -> None:
lhs,
SetExpr(items=[t, f]) | TupleExpr(items=[t, f]) | ListExpr(items=[t, f]),
],
) if (is_true(t) and is_false(f)) or (is_false(t) and is_true(f)):
) if (is_true_literal(t) and is_false_literal(f)) or (
is_false_literal(t) and is_true_literal(f)
):
old = stringify(node)
new = f"isinstance({stringify(lhs)}, bool)"

Expand All @@ -76,7 +67,10 @@ def check(node: ComparisonExpr | OpExpr, errors: list[Error]) -> None:
) if (
lhs_op == rhs_op
and is_equivalent(lhs, rhs)
and ((is_true(t) and is_false(f)) or (is_false(t) and is_true(f)))
and (
(is_true_literal(t) and is_false_literal(f))
or (is_false_literal(t) and is_true_literal(f))
)
):
old = stringify(node)
new = f"isinstance({stringify(lhs)}, bool)"
Expand Down
4 changes: 2 additions & 2 deletions refurb/checks/secrets/simplify_token_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from mypy.nodes import CallExpr, IndexExpr, IntExpr, MemberExpr, NameExpr, RefExpr, SliceExpr

from refurb.checks.common import stringify
from refurb.checks.common import is_none_literal, stringify
from refurb.error import Error


Expand Down Expand Up @@ -49,7 +49,7 @@ def check(node: CallExpr | IndexExpr, errors: list[Error]) -> None:
case [IntExpr(value=value)]:
arg = str(value)

case [NameExpr(fullname="builtins.None")]:
case [arg] if is_none_literal(arg):
arg = "None"

case []:
Expand Down

0 comments on commit 895109c

Please sign in to comment.