diff --git a/docs/checks.md b/docs/checks.md index c3f4668..f364b1e 100644 --- a/docs/checks.md +++ b/docs/checks.md @@ -2180,4 +2180,56 @@ Good: ```python nums = [123, 456] num = str(num[0]) +``` + +## FURB184: `use-fluid-interface` + +Categories: `readability` + +When an API has a Fluent Interface (the ability to chain multiple calls together), you should +chain those calls instead of repeatedly assigning and using the value. +Sometimes a return statement can be written more succinctly: + +Bad: + +```pythonpython +def get_tensors(device: str) -> torch.Tensor: + t1 = torch.ones(2, 1) + t2 = t1.long() + t3 = t2.to(device) + return t3 + +def process(file_name: str): + common_columns = ["col1_renamed", "col2_renamed", "custom_col"] + df = spark.read.parquet(file_name) + df = df \ + .withColumnRenamed('col1', 'col1_renamed') \ + .withColumnRenamed('col2', 'col2_renamed') + df = df \ + .select(common_columns) \ + .withColumn('service_type', F.lit('green')) + return df +``` + +Good: + +```pythonpython +def get_tensors(device: str) -> torch.Tensor: + t3 = ( + torch.ones(2, 1) + .long() + .to(device) + ) + return t3 + +def process(file_name: str): + common_columns = ["col1_renamed", "col2_renamed", "custom_col"] + df = ( + spark.read.parquet(file_name) + .withColumnRenamed('col1', 'col1_renamed') + .withColumnRenamed('col2', 'col2_renamed') + .select(common_columns) + .withColumn('service_type', F.lit('green')) + ) + return df ``` \ No newline at end of file diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py new file mode 100644 index 0000000..aa9a45a --- /dev/null +++ b/refurb/checks/readability/fluid_interface.py @@ -0,0 +1,163 @@ +from dataclasses import dataclass + +from mypy.nodes import ( + AssignmentStmt, + CallExpr, + Expression, + FuncDef, + MemberExpr, + NameExpr, + ReturnStmt, + Statement, +) + +from refurb.checks.common import ReadCountVisitor, check_block_like +from refurb.error import Error +from refurb.visitor import TraverserVisitor + + +@dataclass +class ErrorInfo(Error): + r""" + When an API has a Fluent Interface (the ability to chain multiple calls together), you should + chain those calls instead of repeatedly assigning and using the value. + Sometimes a return statement can be written more succinctly: + + Bad: + + ```python + def get_tensors(device: str) -> torch.Tensor: + t1 = torch.ones(2, 1) + t2 = t1.long() + t3 = t2.to(device) + return t3 + + def process(file_name: str): + common_columns = ["col1_renamed", "col2_renamed", "custom_col"] + df = spark.read.parquet(file_name) + df = df \ + .withColumnRenamed('col1', 'col1_renamed') \ + .withColumnRenamed('col2', 'col2_renamed') + df = df \ + .select(common_columns) \ + .withColumn('service_type', F.lit('green')) + return df + ``` + + Good: + + ```python + def get_tensors(device: str) -> torch.Tensor: + t3 = ( + torch.ones(2, 1) + .long() + .to(device) + ) + return t3 + + def process(file_name: str): + common_columns = ["col1_renamed", "col2_renamed", "custom_col"] + df = ( + spark.read.parquet(file_name) + .withColumnRenamed('col1', 'col1_renamed') + .withColumnRenamed('col2', 'col2_renamed') + .select(common_columns) + .withColumn('service_type', F.lit('green')) + ) + return df + ``` + """ + + name = "use-fluid-interface" + code = 184 + categories = ("readability",) + + +def check(node: FuncDef, errors: list[Error]) -> None: + check_block_like(check_stmts, node.body, errors) + + +def check_call(node: Expression, name: str | None = None) -> bool: + match node: + # Single chain + case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=_)): + if name is None or name == x: + # Exclude other references + x_expr = NameExpr(x) + x_expr.fullname = x + visitor = ReadCountVisitor(x_expr) + visitor.accept(node) + return visitor.read_count == 1 + return False + + # Nested + case CallExpr(callee=MemberExpr(expr=call_node, name=_)): + return check_call(call_node, name=name) + + return False + + +class NameReferenceVisitor(TraverserVisitor): + name: NameExpr + referenced: bool + + def __init__(self, name: NameExpr, stmt: Statement | None = None) -> None: + super().__init__() + self.name = name + self.stmt = stmt + self.referenced = False + + def visit_name_expr(self, node: NameExpr) -> None: + if not self.referenced and node.fullname == self.name.fullname: + self.referenced = True + + +def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: + last = "" + visitors: list[NameReferenceVisitor] = [] + + for stmt in stmts: + for visitor in visitors: + visitor.accept(stmt) + # No need to track referenced variables anymore + visitors = [visitor for visitor in visitors if not visitor.referenced] + + match stmt: + case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue): + if last and check_call(rvalue, name=last): + if f"{last}'" == name: + errors.append( + ErrorInfo.from_node( + stmt, + "Assignment statement should be chained", + ) + ) + else: + # We need to ensure that the variable is not referenced somewhere else + name_expr = NameExpr(name=last) + name_expr.fullname = last + visitors.append(NameReferenceVisitor(name_expr, stmt)) + + last = name if name != "_" else "" + case ReturnStmt(expr=rvalue): + if last and rvalue is not None and check_call(rvalue, name=last): + errors.append( + ErrorInfo.from_node( + stmt, + "Return statement should be chained", + ) + ) + case _: + last = "" + + # Ensure that variables are not referenced + errors.extend( + [ + ErrorInfo.from_node( + visitor.stmt, + "Assignment statement should be chained", + ) + for visitor in visitors + if not visitor.referenced and visitor.stmt is not None + ] + ) diff --git a/test/data/err_184.py b/test/data/err_184.py new file mode 100644 index 0000000..f34cb92 --- /dev/null +++ b/test/data/err_184.py @@ -0,0 +1,163 @@ +class torch: + @staticmethod + def ones(*args): + return torch + + @staticmethod + def long(): + return torch + + @staticmethod + def to(device: str): + return torch.Tensor() + + class Tensor: + pass + + +def transform(x): + return x + + +class spark: + class read: + @staticmethod + def parquet(file_name: str): + return spark.DataFrame() + + class functions: + @staticmethod + def lit(constant): + return constant + + @staticmethod + def col(col_name): + return col_name + + class DataFrame: + @staticmethod + def withColumnRenamed(col_in, col_out): + return spark.DataFrame() + + @staticmethod + def withColumn(col_in, col_out): + return spark.DataFrame() + + @staticmethod + def select(*args): + return spark.DataFrame() + +class F: + @staticmethod + def lit(value): + return value + + +# these will match +def get_tensors(device: str) -> torch.Tensor: + a = torch.ones(2, 1) + a = a.long() + a = a.to(device) + return a + + +def process(file_name: str): + common_columns = ["col1_renamed", "col2_renamed", "custom_col"] + df = spark.read.parquet(file_name) + df = df \ + .withColumnRenamed('col1', 'col1_renamed') \ + .withColumnRenamed('col2', 'col2_renamed') + df = df \ + .select(common_columns) \ + .withColumn('service_type', spark.functions.lit('green')) + return df + + +def projection(df_in: spark.DataFrame) -> spark.DataFrame: + df = ( + df_in.select(["col1", "col2"]) + .withColumnRenamed("col1", "col1a") + ) + return df.withColumn("col2a", spark.functions.col("col2").cast("date")) + + +def assign_multiple(df): + df = df.select("column") + result_df = df.select("another_column") + final_df = result_df.withColumn("column2", F.lit("abc")) + return final_df + + +# not yet supported +def assign_alternating(df, df2): + df = df.select("column") + df2 = df2.select("another_column") + df = df.withColumn("column2", F.lit("abc")) + return df, df2 + + +# these will not +def ignored(x): + _ = x.op1() + _ = _.op2() + return _ + +def _(x): + y = x.m() + return y.operation(*[v for v in y]) + + +def assign_multiple_referenced(df, df2): + df = df.select("column") + result_df = df.select("another_column") + return df, result_df + + +def invalid(df_in: spark.DataFrame, alternative_df: spark.DataFrame) -> spark.DataFrame: + df = ( + df_in.select(["col1", "col2"]) + .withColumnRenamed("col1", "col1a") + ) + return alternative_df.withColumn("col2a", spark.functions.col("col2").cast("date")) + + +def no_match(): + y = 10 + y = transform(y) + return y + +def f(x): + if x: + name = "alice" + stripped = name.strip() + print(stripped) + else: + name = "bob" + print(name) + +def g(x): + try: + name = "alice" + stripped = name.strip() + print(stripped) + except ValueError: + name = "bob" + print(name) + +def h(x): + for _ in (1, 2, 3): + name = "alice" + stripped = name.strip() + print(stripped) + else: + name = "bob" + print(name) + +def assign_multiple_try(df): + try: + df = df.select("column") + result_df = df.select("another_column") + final_df = result_df.withColumn("column2", F.lit("abc")) + return final_df + except ValueError: + return None diff --git a/test/data/err_184.txt b/test/data/err_184.txt new file mode 100644 index 0000000..7cd32b5 --- /dev/null +++ b/test/data/err_184.txt @@ -0,0 +1,7 @@ +test/data/err_184.py:59:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:60:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:67:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:70:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:81:5 [FURB184]: Return statement should be chained +test/data/err_184.py:86:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:87:5 [FURB184]: Assignment statement should be chained