Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation for fluent interface check #287

Merged
merged 11 commits into from
Dec 22, 2023
167 changes: 167 additions & 0 deletions refurb/checks/readability/fluid_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from dataclasses import dataclass

from mypy.nodes import (
AssignmentStmt,
CallExpr,
Expression,
FuncDef,
MemberExpr,
NameExpr,
ReturnStmt,
Statement,
)

from refurb.checks.common import ReadCountVisitor, check_block_like
from refurb.error import Error
from refurb.visitor import TraverserVisitor


@dataclass
class ErrorInfo(Error):
    r"""
    When an API has a Fluent Interface (the ability to chain multiple calls together), you should
    chain those calls instead of repeatedly assigning and using the value.
    Sometimes a return statement can be written more succinctly:

    Bad:

    ```python
    def get_tensors(device: str) -> torch.Tensor:
        t1 = torch.ones(2, 1)
        t2 = t1.long()
        t3 = t2.to(device)
        return t3

    def process(file_name: str):
        common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
        df = spark.read.parquet(file_name)
        df = df \
            .withColumnRenamed('col1', 'col1_renamed') \
            .withColumnRenamed('col2', 'col2_renamed')
        df = df \
            .select(common_columns) \
            .withColumn('service_type', F.lit('green'))
        return df
    ```

    Good:

    ```python
    def get_tensors(device: str) -> torch.Tensor:
        t3 = (
            torch.ones(2, 1)
            .long()
            .to(device)
        )
        return t3

    def process(file_name: str):
        common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
        df = (
            spark.read.parquet(file_name)
            .withColumnRenamed('col1', 'col1_renamed')
            .withColumnRenamed('col2', 'col2_renamed')
            .select(common_columns)
            .withColumn('service_type', F.lit('green'))
        )
        return df
    ```
    """

    # Check identifier, error code, and category used by refurb's registry.
    name = "use-fluid-interface"
    code = 184
    categories = ("readability",)


def check(node: FuncDef, errors: list[Error]) -> None:
    """Run the fluent-interface check over the statements of a function body.

    NOTE(review): `check_block_like` presumably applies `check_stmts` to each
    statement block inside the body (e.g. if/try/loop branches) — confirm in
    refurb.checks.common.
    """
    check_block_like(check_stmts, node.body, errors)


def check_call(node: Expression, name: str | None = None) -> bool:
    """Return True if *node* is a method-call chain rooted at a bare name.

    When *name* is given, the chain must be rooted at exactly that variable,
    and the variable must be read only once in the whole expression (so it is
    not also used as an argument somewhere inside the chain).
    """
    if not isinstance(node, CallExpr) or not isinstance(node.callee, MemberExpr):
        return False

    root = node.callee.expr
    if not isinstance(root, NameExpr):
        # Nested chain such as `x.a().b()`: recurse down to the innermost call.
        return check_call(root, name=name)

    if name is not None and name != root.name:
        return False

    # Exclude other references to the root variable inside the expression,
    # e.g. `y.operation(*[v for v in y])` reads `y` twice and must not match.
    ref = NameExpr(root.name)
    ref.fullname = root.name
    counter = ReadCountVisitor(ref)
    counter.accept(node)
    return counter.read_count == 1


class NameReferenceVisitor(TraverserVisitor):
    """AST visitor that records whether a given name is ever referenced.

    Used to make sure the variable consumed by a candidate assignment is not
    read anywhere later in the block before the assignment is reported.
    """

    # Name being watched for (matched via `fullname`).
    name: NameExpr
    # Set to True on the first matching reference; never reset.
    referenced: bool
    # Statement to report if the name turns out to be unreferenced.
    stmt: Statement | None

    def __init__(self, name: NameExpr, stmt: Statement | None = None) -> None:
        super().__init__()
        self.name = name
        self.stmt = stmt
        self.referenced = False

    def visit_name_expr(self, node: NameExpr) -> None:
        # Compare by fullname so distinct NameExpr nodes for the same variable
        # match; skip the comparison once a reference has been found.
        if not self.referenced and node.fullname == self.name.fullname:
            self.referenced = True

    @property
    def was_referenced(self) -> bool:
        return self.referenced


def check_stmts(stmts: list[Statement], errors: list[Error]) -> None:
    """Flag statements in *stmts* that could be chained onto a previous call.

    Two patterns are reported:

    * a re-assignment (``x = x.f()``) directly following an assignment to
      ``x``, reported immediately;
    * an intermediate assignment (``y = x.f()``), reported only after the rest
      of the block has been scanned and ``x`` was never referenced again;
    * a ``return x.f()`` where ``x`` was bound by the previous statement.
    """
    # Name bound by the previous statement; "" means the previous statement
    # cannot be chained onto (non-assignment, or a `_` throwaway target).
    last = ""
    # Pending candidates: each visitor watches later statements for references
    # to the variable its candidate assignment consumed.
    visitors: list[NameReferenceVisitor] = []

    for stmt in stmts:
        for visitor in visitors:
            visitor.accept(stmt)
        # No need to track referenced variables anymore
        visitors = [visitor for visitor in visitors if not visitor.referenced]

        match stmt:
            case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue):
                if last and check_call(rvalue, name=last):
                    # NOTE(review): mypy's variable-renaming pass gives a
                    # redefined variable a trailing prime (`x` becomes `x'`),
                    # so this detects re-assignments like `x = x.f()` —
                    # confirm the convention holds for supported mypy versions.
                    if f"{last}'" == name:
                        errors.append(
                            ErrorInfo.from_node(
                                stmt,
                                "Assignment statement should be chained",
                            )
                        )
                    else:
                        # We need to ensure that the variable is not referenced somewhere else
                        name_expr = NameExpr(name=last)
                        name_expr.fullname = last
                        visitors.append(NameReferenceVisitor(name_expr, stmt))

                # `_` is a conventional throwaway name; never chain onto it.
                last = name if name != "_" else ""
            case ReturnStmt(expr=rvalue):
                if last and rvalue is not None and check_call(rvalue, name=last):
                    errors.append(
                        ErrorInfo.from_node(
                            stmt,
                            "Return statement should be chained",
                        )
                    )
            case _:
                # Any other kind of statement breaks the chain.
                last = ""

    # Ensure that variables are not referenced
    errors.extend(
        [
            ErrorInfo.from_node(
                visitor.stmt,
                "Assignment statement should be chained",
            )
            for visitor in visitors
            if not visitor.referenced and visitor.stmt is not None
        ]
    )
163 changes: 163 additions & 0 deletions test/data/err_184.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
class torch:  # minimal stand-in for PyTorch; trailing comments only (err_184.txt pins line numbers)
    @staticmethod
    def ones(*args):  # chain root: returns the class itself
        return torch

    @staticmethod
    def long():  # chainable
        return torch

    @staticmethod
    def to(device: str):  # terminal call in the example chain
        return torch.Tensor()

    class Tensor:
        pass


def transform(x):  # plain free function: calls to it are not method chains
    return x


class spark:  # minimal stand-in for pyspark; trailing comments only (err_184.txt pins line numbers)
    class read:
        @staticmethod
        def parquet(file_name: str):  # chain root in the fixtures
            return spark.DataFrame()

    class functions:
        @staticmethod
        def lit(constant):
            return constant

        @staticmethod
        def col(col_name):
            return col_name

    class DataFrame:
        @staticmethod
        def withColumnRenamed(col_in, col_out):  # chainable
            return spark.DataFrame()

        @staticmethod
        def withColumn(col_in, col_out):  # chainable
            return spark.DataFrame()

        @staticmethod
        def select(*args):  # chainable
            return spark.DataFrame()

class F:  # stand-in for the conventional `pyspark.sql.functions as F` alias
    @staticmethod
    def lit(value):
        return value


# these will match
def get_tensors(device: str) -> torch.Tensor:
    a = torch.ones(2, 1)
    a = a.long()  # FURB184: re-assignment should be chained
    a = a.to(device)  # FURB184: re-assignment should be chained
    return a


def process(file_name: str):
    common_columns = ["col1_renamed", "col2_renamed", "custom_col"]
    df = spark.read.parquet(file_name)
    df = df \
        .withColumnRenamed('col1', 'col1_renamed') \
        .withColumnRenamed('col2', 'col2_renamed')  # FURB184 on this assignment
    df = df \
        .select(common_columns) \
        .withColumn('service_type', spark.functions.lit('green'))  # FURB184 on this assignment
    return df


def projection(df_in: spark.DataFrame) -> spark.DataFrame:
    df = (
        df_in.select(["col1", "col2"])
        .withColumnRenamed("col1", "col1a")
    )
    return df.withColumn("col2a", spark.functions.col("col2").cast("date"))  # FURB184: return should be chained


def assign_multiple(df):
    df = df.select("column")
    result_df = df.select("another_column")  # FURB184: `df` never read again
    final_df = result_df.withColumn("column2", F.lit("abc"))  # FURB184: `result_df` never read again
    return final_df


# not yet supported
def assign_alternating(df, df2):
    df = df.select("column")
    df2 = df2.select("another_column")  # interleaved assignment resets `last`
    df = df.withColumn("column2", F.lit("abc"))  # so this is (currently) not flagged
    return df, df2


# these will not
def ignored(x):
    _ = x.op1()  # `_` is a throwaway target and is never chained onto
    _ = _.op2()
    return _

def _(x):
    y = x.m()
    return y.operation(*[v for v in y])  # `y` read twice here, so not chainable


def assign_multiple_referenced(df, df2):
    df = df.select("column")
    result_df = df.select("another_column")
    return df, result_df  # both names still read here, so nothing is flagged


def invalid(df_in: spark.DataFrame, alternative_df: spark.DataFrame) -> spark.DataFrame:
    df = (
        df_in.select(["col1", "col2"])
        .withColumnRenamed("col1", "col1a")
    )
    return alternative_df.withColumn("col2a", spark.functions.col("col2").cast("date"))  # rooted at a different name


def no_match():
    y = 10
    y = transform(y)  # plain function call, not a method chain: no error
    return y

def f(x):
    if x:
        name = "alice"
        stripped = name.strip()
        print(stripped)
    else:
        name = "bob"
        print(name)  # `name` is read here, so the if-branch assignment is not flagged

def g(x):
    try:
        name = "alice"
        stripped = name.strip()
        print(stripped)
    except ValueError:
        name = "bob"
        print(name)  # `name` is read here, so the try-body assignment is not flagged

def h(x):
    for _ in (1, 2, 3):
        name = "alice"
        stripped = name.strip()
        print(stripped)
    else:
        name = "bob"
        print(name)  # `name` is read here, so the loop-body assignment is not flagged

def assign_multiple_try(df):
    try:
        df = df.select("column")
        result_df = df.select("another_column")  # inside try: currently not flagged
        final_df = result_df.withColumn("column2", F.lit("abc"))
        return final_df
    except ValueError:
        return None
7 changes: 7 additions & 0 deletions test/data/err_184.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
test/data/err_184.py:59:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:60:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:67:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:70:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:81:5 [FURB184]: Return statement should be chained
test/data/err_184.py:86:5 [FURB184]: Assignment statement should be chained
test/data/err_184.py:87:5 [FURB184]: Assignment statement should be chained