Skip to content

Commit

Permalink
Add better type deduction for more expressions:
Browse files Browse the repository at this point in the history
* Deduce type of walrus operator (`:=`) expressions
* Deduce type of awaited expressions
* Deduce type of `cast()` functions
* Deduce `set` literals as type `set`
* Deduce type for unary operations
* Deduce type of index (`x[0]`) expressions
* Deduce return type when calling lambda expressions
  • Loading branch information
dosisod committed Feb 17, 2024
1 parent 5a6d61a commit 5c15349
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 32 deletions.
110 changes: 78 additions & 32 deletions refurb/checks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

from mypy.nodes import (
ArgKind,
AssignmentExpr,
AssignmentStmt,
AwaitExpr,
Block,
BytesExpr,
CallExpr,
CastExpr,
ComparisonExpr,
ComplexExpr,
ConditionalExpr,
Expand Down Expand Up @@ -36,7 +39,7 @@
StarExpr,
Statement,
StrExpr,
SymbolTableNode,
SymbolNode,
TupleExpr,
TypeAlias,
TypeInfo,
Expand Down Expand Up @@ -445,6 +448,12 @@ def _stringify(node: Node) -> str:
case ExpressionStmt(expr=expr):
return _stringify(expr)

case AwaitExpr(expr=expr):
return f"await {_stringify(expr)}"

case AssignmentExpr(target=lhs, value=rhs):
return f"{_stringify(lhs)} := {_stringify(rhs)}"

raise ValueError


Expand All @@ -463,7 +472,7 @@ def slice_expr_to_slice_call(expr: SliceExpr) -> str:
TypeLike = type | str | None | object


def is_same_type(ty: Type | TypeInfo | None, *expected: TypeLike) -> bool:
def is_same_type(ty: Type | SymbolNode | None, *expected: TypeLike) -> bool:
"""
Check if the type `ty` matches any of the `expected` types. `ty` must be a Mypy type object,
but the expected types can be any of the following:
Expand Down Expand Up @@ -498,7 +507,7 @@ def is_same_type(ty: Type | TypeInfo | None, *expected: TypeLike) -> bool:
}


def _is_same_type(ty: Type | TypeInfo | None, expected: TypeLike) -> bool:
def _is_same_type(ty: Type | SymbolNode | None, expected: TypeLike) -> bool:
if ty is expected is None:
return True

Expand All @@ -520,14 +529,17 @@ def _is_same_type(ty: Type | TypeInfo | None, expected: TypeLike) -> bool:
return False


def _get_builtin_mypy_type(name: str) -> Type | None:
def _get_builtin_mypy_type(name: str) -> Instance | None:
if (sym := types.BUILTINS_MYPY_FILE.names.get(name)) and isinstance(sym.node, TypeInfo):
return Instance(sym.node, [])

return None # pragma: no cover


def get_mypy_type(node: Node) -> Type | None:
def get_mypy_type(node: Node) -> Type | SymbolNode | None:
# forward declaration to make Mypy happy
ty: Type | SymbolNode | None

match node:
case StrExpr():
return _get_builtin_mypy_type("str")
Expand Down Expand Up @@ -556,52 +568,86 @@ def get_mypy_type(node: Node) -> Type | None:
case TupleExpr():
return _get_builtin_mypy_type("tuple")

case Var(type=ty):
case SetExpr():
return _get_builtin_mypy_type("set")

case Var(type=ty) | FuncDef(type=ty):
return ty

case NameExpr(node=sym):
match sym:
case Var(type=ty) | Instance(type=ty): # type: ignore
return ty
case TypeInfo() | TypeAlias() | MypyFile():
return node

case TypeAlias(target=ty):
case NameExpr(node=sym) if sym:
return get_mypy_type(sym)

case MemberExpr(expr=lhs, name=name):
ty = get_mypy_type(lhs)

if (
isinstance(ty, MypyFile | TypeInfo)
and (member := ty.names.get(name))
and member.node
):
return get_mypy_type(member.node)

if isinstance(ty, Instance) and (member := ty.type.get(name)) and member.node:
return get_mypy_type(member.node)

case CallExpr(analyzed=CastExpr(type=ty)):
return ty

case CallExpr(callee=callee):
match get_mypy_type(callee):
case CallableType(ret_type=ty):
return ty

case FuncDef(type=CallableType(ret_type=ty)):
case TypeAlias(target=ty):
return ty

case TypeInfo():
case TypeInfo() as sym:
return Instance(sym, [])

case MemberExpr(expr=lhs, name=name):
# TODO: don't special case this
match lhs:
case NameExpr(node=MypyFile(names=names)):
match names.get(name):
case SymbolTableNode(node=FuncDef(type=CallableType(ret_type=ty))):
return ty
case UnaryExpr(op="not"):
return _get_builtin_mypy_type("bool")

case SymbolTableNode(node=TypeInfo() as ty): # type: ignore
return Instance(ty, []) # type: ignore
case UnaryExpr(method_type=CallableType(ret_type=ty)):
return ty

lhs_type = get_mypy_type(lhs)
case OpExpr(method_type=CallableType(ret_type=ty)):
return ty

if isinstance(lhs_type, Instance):
sym = lhs_type.type.get(name) # type: ignore
case IndexExpr(method_type=CallableType(ret_type=ty)):
return ty

if sym and sym.node: # type: ignore
return get_mypy_type(sym.node) # type: ignore
case AwaitExpr(expr=expr):
ty = get_mypy_type(expr)

case CallExpr(callee=callee):
return get_mypy_type(callee)
# TODO: allow for any Awaitable[T] type
match ty:
case Instance(type=TypeInfo(fullname="typing.Coroutine"), args=[_, _, rtype]):
return rtype

case OpExpr(method_type=CallableType(ret_type=ty)):
return ty
case Instance(type=TypeInfo(fullname="asyncio.tasks.Task"), args=[rtype]):
return rtype

case LambdaExpr(body=Block(body=[ReturnStmt(expr=expr)])) if expr:
if (ty := get_mypy_type(expr)) and isinstance(ty, Type):
return _build_placeholder_callable(ty)

case AssignmentExpr(target=expr):
return get_mypy_type(expr)

return None


def mypy_type_to_python_type(ty: Type | None) -> type | None:
def _build_placeholder_callable(rtype: Type) -> Type | None:
if function := _get_builtin_mypy_type("function"):
return CallableType([], [], [], ret_type=rtype, fallback=function)

return None # pragma: no cover


def mypy_type_to_python_type(ty: Type | SymbolNode | None) -> type | None:
match ty:
# TODO: return annotated types if instance has args (ie, `list[int]`)
case Instance(type=TypeInfo(fullname=fullname)):
Expand Down
1 change: 1 addition & 0 deletions refurb/checks/readability/no_unnecessary_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ErrorInfo(Error):
"builtins.float": (float, ""),
"builtins.int": (int, ""),
"builtins.list": (list, ".copy()"),
"builtins.set": (set, ""),
"builtins.str": (str, ""),
"builtins.tuple": (tuple, ""),
}
Expand Down
4 changes: 4 additions & 0 deletions test/data/err_123.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def func() -> bool:

_ = bool(func())

s = {1}
_ = set(s)
_ = set({1})


# these will not

Expand Down
2 changes: 2 additions & 0 deletions test/data/err_123.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ test/data/err_123.py:29:5 [FURB123]: Replace `list(f)` with `f.copy()`
test/data/err_123.py:32:5 [FURB123]: Replace `str(g)` with `g`
test/data/err_123.py:35:5 [FURB123]: Replace `tuple(t)` with `t`
test/data/err_123.py:40:5 [FURB123]: Replace `bool(func())` with `func()`
test/data/err_123.py:43:5 [FURB123]: Replace `set(s)` with `s`
test/data/err_123.py:44:5 [FURB123]: Replace `set({1})` with `{1}`
40 changes: 40 additions & 0 deletions test/data/type_deduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# These are a variety of checks to ensure Refurb is able to deduce types from
# complex expressions.

_ = bool([True][0])


async def async_wrapper():
import asyncio

async def return_bool() -> bool:
return True

task = asyncio.create_task(return_bool())

_ = bool(await return_bool())
_ = bool(await task)


lambda_return_bool = lambda: True
_ = bool(lambda_return_bool())
_ = bool((lambda: True)()) # TODO: error message should include parens around lambda

bool_value = True

_ = bool(not bool_value)
_ = bool(not False)

_ = int(-1)
_ = int(+1)
_ = int(~1)

_ = bool(walrus := True)

from typing import cast

_ = bool(cast(bool, 123))


# These types are not able to be deduced (yet)
_ = int(1 or 2)
12 changes: 12 additions & 0 deletions test/data/type_deduce.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
test/data/type_deduce.py:4:5 [FURB123]: Replace `bool([True][0])` with `[True][0]`
test/data/type_deduce.py:15:9 [FURB123]: Replace `bool(await return_bool())` with `await return_bool()`
test/data/type_deduce.py:16:9 [FURB123]: Replace `bool(await task)` with `await task`
test/data/type_deduce.py:20:5 [FURB123]: Replace `bool(lambda_return_bool())` with `lambda_return_bool()`
test/data/type_deduce.py:21:5 [FURB123]: Replace `bool(lambda: True())` with `lambda: True()`
test/data/type_deduce.py:25:5 [FURB123]: Replace `bool(not bool_value)` with `not bool_value`
test/data/type_deduce.py:26:5 [FURB123]: Replace `bool(not False)` with `not False`
test/data/type_deduce.py:28:5 [FURB123]: Replace `int(-1)` with `-1`
test/data/type_deduce.py:29:5 [FURB123]: Replace `int(+1)` with `+1`
test/data/type_deduce.py:30:5 [FURB123]: Replace `int(~1)` with `~1`
test/data/type_deduce.py:32:5 [FURB123]: Replace `bool(walrus := True)` with `walrus := True`
test/data/type_deduce.py:36:5 [FURB123]: Replace `bool(cast(bool, 123))` with `cast(bool, 123)`

0 comments on commit 5c15349

Please sign in to comment.