Skip to content

Commit

Permalink
Add better type deduction for mapping types
Browse files Browse the repository at this point in the history
  • Loading branch information
dosisod committed Mar 1, 2024
1 parent c30b449 commit d8ec9f4
Show file tree
Hide file tree
Showing 13 changed files with 79 additions and 34 deletions.
5 changes: 2 additions & 3 deletions refurb/checks/builtin/no_ignored_dict_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

from refurb.checks.common import (
check_for_loop_like,
get_mypy_type,
is_mapping,
is_name_unused_in_contexts,
is_same_type,
stringify,
)
from refurb.error import Error
Expand Down Expand Up @@ -75,7 +74,7 @@ def check_dict_items_call(
callee=MemberExpr(expr=dict_expr, name="items"),
args=[],
),
) if is_same_type(get_mypy_type(dict_expr), dict):
) if is_mapping(dict_expr):
check_unused_key_or_value(key, value, contexts, errors, dict_expr)


Expand Down
32 changes: 32 additions & 0 deletions refurb/checks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,3 +721,35 @@ def mypy_type_to_python_type(ty: Type | SymbolNode | None) -> type | None:
return SIMPLE_TYPES.get(fullname) # type: ignore

return None # pragma: no cover


MAPPING_TYPES = (
dict,
"collections.ChainMap",
"collections.Counter",
"collections.OrderedDict",
"collections.UserDict",
"collections.abc.Mapping",
"collections.abc.MutableMapping",
"collections.defaultdict",
"typing.Mapping",
"typing.MutableMapping",
)


# TODO: support any Mapping subclass
def is_mapping(expr: Expression) -> bool:
return is_mapping_type(get_mypy_type(expr))


def is_mapping_type(ty: Type | SymbolNode | None) -> bool:
return is_same_type(ty, *MAPPING_TYPES)


def is_sized(node: Expression) -> bool:
return is_sized_type(get_mypy_type(node))


# TODO: support any Sized subclass
def is_sized_type(ty: Type | SymbolNode | None) -> bool:
return is_mapping_type(ty) or is_same_type(ty, list, tuple, set, frozenset, str)
4 changes: 2 additions & 2 deletions refurb/checks/readability/in_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from mypy.nodes import CallExpr, ComparisonExpr, MemberExpr

from refurb.checks.common import get_mypy_type, is_same_type, stringify
from refurb.checks.common import is_mapping, stringify
from refurb.error import Error


Expand Down Expand Up @@ -47,7 +47,7 @@ def check(node: ComparisonExpr, errors: list[Error]) -> None:
args=[],
) as expr,
],
) if is_same_type(get_mypy_type(dict_expr), dict):
) if is_mapping(dict_expr):
dict_expr = stringify(dict_expr) # type: ignore

msg = f"Replace `{oper} {stringify(expr)}` with `{oper} {dict_expr}`"
Expand Down
14 changes: 3 additions & 11 deletions refurb/checks/readability/no_len_cmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
ConditionalExpr,
DictExpr,
DictionaryComprehension,
Expression,
GeneratorExpr,
IfStmt,
IntExpr,
Expand All @@ -20,7 +19,7 @@
WhileStmt,
)

from refurb.checks.common import get_mypy_type, is_same_type, stringify
from refurb.checks.common import is_sized, stringify
from refurb.error import Error
from refurb.visitor import METHOD_NODE_MAPPINGS, TraverserVisitor

Expand Down Expand Up @@ -61,16 +60,9 @@ class ErrorInfo(Error):
categories = ("iterable", "truthy")


def is_builtin_container_like(node: Expression) -> bool:
return is_same_type(get_mypy_type(node), list, tuple, dict, set, frozenset, str)


def is_len_call(node: CallExpr) -> bool:
match node:
case CallExpr(
callee=NameExpr(fullname="builtins.len"),
args=[arg],
) if is_builtin_container_like(arg):
case CallExpr(callee=NameExpr(fullname="builtins.len"), args=[arg]) if is_sized(arg):
return True

return False
Expand Down Expand Up @@ -130,7 +122,7 @@ def visit_comparison_expr(self, node: ComparisonExpr) -> None:
case ComparisonExpr(
operators=["==" | "!=" as oper],
operands=[lhs, (ListExpr(items=[]) | DictExpr(items=[]))],
) if is_builtin_container_like(lhs):
) if is_sized(lhs):
old = stringify(node)
new = stringify(lhs)

Expand Down
22 changes: 4 additions & 18 deletions refurb/checks/readability/use_dict_union.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from dataclasses import dataclass
from itertools import groupby

from mypy.nodes import ArgKind, CallExpr, DictExpr, Expression, RefExpr
from mypy.nodes import ArgKind, CallExpr, DictExpr, RefExpr

from refurb.checks.common import get_mypy_type, is_same_type, stringify
from refurb.checks.common import is_mapping, stringify
from refurb.error import Error
from refurb.settings import Settings

Expand Down Expand Up @@ -38,20 +38,6 @@ def add_defaults(settings: dict[str, str]) -> dict[str, str]:
categories = ("dict", "readability")


MAPPING_TYPES = (
dict,
"collections.ChainMap",
"collections.Counter",
"collections.OrderedDict",
"collections.defaultdict",
"collections.UserDict",
)


def is_builtin_mapping(expr: Expression) -> bool:
return is_same_type(get_mypy_type(expr), *MAPPING_TYPES)


def check(node: DictExpr | CallExpr, errors: list[Error], settings: Settings) -> None:
if settings.get_python_version() < (3, 9):
return # pragma: no cover
Expand Down Expand Up @@ -83,7 +69,7 @@ def check(node: DictExpr | CallExpr, errors: list[Error], settings: Settings) ->
if is_star:
_, star_expr = pair

if not is_builtin_mapping(star_expr):
if not is_mapping(star_expr):
return

old.append(f"**{stringify(star_expr)}")
Expand Down Expand Up @@ -121,7 +107,7 @@ def check(node: DictExpr | CallExpr, errors: list[Error], settings: Settings) ->
return

if kind == ArgKind.ARG_STAR2:
if not is_builtin_mapping(arg):
if not is_mapping(arg):
return

stringified_arg = stringify(arg)
Expand Down
7 changes: 7 additions & 0 deletions test/data/err_115.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ class C:
assert c.l == []


from collections.abc import Mapping

def mapping_check(m: Mapping[str, str]):
if len(m) == 0:
pass


# these should not

if len(nums) == 1: ...
Expand Down
1 change: 1 addition & 0 deletions test/data/err_115.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ test/data/err_115.py:73:8 [FURB115]: Replace `len(nums)` with `nums`
test/data/err_115.py:74:8 [FURB115]: Replace `len(nums)` with `nums`
test/data/err_115.py:80:8 [FURB115]: Replace `C().l == []` with `not C().l`
test/data/err_115.py:83:8 [FURB115]: Replace `c.l == []` with `not c.l`
test/data/err_115.py:89:8 [FURB115]: Replace `len(m) == 0` with `not m`
7 changes: 7 additions & 0 deletions test/data/err_130.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ class C:
pass


from collections.abc import Mapping

def mapping_check(m: Mapping[str, str]):
if x in m.keys():
pass


# these should not

if "key" in d:
Expand Down
1 change: 1 addition & 0 deletions test/data/err_130.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ test/data/err_130.py:5:13 [FURB130]: Replace `in d.keys()` with `in d`
test/data/err_130.py:8:17 [FURB130]: Replace `not in d.keys()` with `not in d`
test/data/err_130.py:12:9 [FURB130]: Replace `in d.keys()` with `in d`
test/data/err_130.py:18:9 [FURB130]: Replace `in C().d.keys()` with `in C().d`
test/data/err_130.py:25:13 [FURB130]: Replace `in m.keys()` with `in m`
7 changes: 7 additions & 0 deletions test/data/err_135.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def f6():
print(k, v)


from collections.abc import Mapping

def mapping_check(m: Mapping[str, str]):
for k, v in m.items():
pass


# these should not

def f7():
Expand Down
2 changes: 2 additions & 0 deletions test/data/err_135.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ test/data/err_135.py:39:15 [FURB135]: Value is unused, use `for k in c.d` instea
test/data/err_135.py:40:12 [FURB135]: Key is unused, use `for v in c.d.values()` instead
test/data/err_135.py:49:9 [FURB135]: Key is unused, use `for v in d.values()` instead
test/data/err_135.py:49:12 [FURB135]: Value is unused, use `for k in d` instead
test/data/err_135.py:58:9 [FURB135]: Key is unused, use `for v in m.values()` instead
test/data/err_135.py:58:12 [FURB135]: Value is unused, use `for k in m` instead
9 changes: 9 additions & 0 deletions test/data/err_173.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ class Wrapper:

_ = {**Wrapper().d, **x}

from collections.abc import Mapping, MutableMapping


def mapping_test(m: Mapping[str, str]):
_ = dict(**m)

def mutable_mapping_test(m: MutableMapping[str, str]):
_ = dict(**m)


# these should not

Expand Down
2 changes: 2 additions & 0 deletions test/data/err_173.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ test/data/err_173.py:41:5 [FURB173]: Replace `dict(**x, a=1, b=2)` with `x | {"a
test/data/err_173.py:42:5 [FURB173]: Replace `dict(**x, **y, a=1, b=2)` with `x | y | {"a": 1, "b": 2}`
test/data/err_173.py:43:5 [FURB173]: Replace `dict(**x, **{})` with `x | {}`
test/data/err_173.py:48:5 [FURB173]: Replace `{**Wrapper().d, **x}` with `Wrapper().d | x`
test/data/err_173.py:54:9 [FURB173]: Replace `dict(**m)` with `{**m}`
test/data/err_173.py:57:9 [FURB173]: Replace `dict(**m)` with `{**m}`

0 comments on commit d8ec9f4

Please sign in to comment.