From 6090977e2f83faff1803d47e8cc82ee7813d1207 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Sat, 9 Dec 2023 22:34:11 +0100 Subject: [PATCH] support return statements and non-matching names --- refurb/checks/readability/fluid_interface.py | 30 +++++++++++++------- test/data/err_184.py | 28 ++++++++++++++++++ test/data/err_184.txt | 11 ++++--- 3 files changed, 54 insertions(+), 15 deletions(-) diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py index 338a0b5..4a46ea0 100644 --- a/refurb/checks/readability/fluid_interface.py +++ b/refurb/checks/readability/fluid_interface.py @@ -8,6 +8,7 @@ CallExpr, MemberExpr, NameExpr, + ReturnStmt, ) from refurb.checks.common import check_block_like @@ -24,10 +25,10 @@ class ErrorInfo(Error): ```python def get_tensors(device: str) -> torch.Tensor: - a = torch.ones(2, 1) - a = a.long() - a = a.to(device) - return a + t1 = torch.ones(2, 1) + t2 = t1.long() + t3 = t2.to(device) + return t3 def process(file_name: str): @@ -46,12 +47,12 @@ def process(file_name: str): ```python def get_tensors(device: str) -> torch.Tensor: - a = ( + t3 = ( torch.ones(2, 1) .long() .to(device) ) - return a + return t3 def process(file_name: str): common_columns = ["col1_renamed", "col2_renamed", "custom_col"] @@ -75,11 +76,11 @@ def check(node: Block | MypyFile, errors: list[Error]) -> None: check_block_like(check_stmts, node, errors) -def check_call(node) -> bool: +def check_call(node, name: str | None = None) -> bool: match node: # Single chain case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=y)): - return True + return name is None or name == x # Nested case CallExpr(callee=MemberExpr(expr=call_node, name=y)): return check_call(call_node) @@ -93,15 +94,22 @@ def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: for stmt in stmts: match stmt: case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue): - if last and f"{last}'" == name and check_call(rvalue): + if last and check_call(rvalue, name=last): errors.append( ErrorInfo.from_node( stmt, - f"Assignment statements should be chained", + f"Assignment statement should be chained", ) ) last = name - + case ReturnStmt(expr=rvalue): + if last and check_call(rvalue, name=last): + errors.append( + ErrorInfo.from_node( + stmt, + f"Return statement should be chained", + ) + ) case _: last = "" diff --git a/test/data/err_184.py b/test/data/err_184.py index 825bc17..6064de3 100644 --- a/test/data/err_184.py +++ b/test/data/err_184.py @@ -47,6 +47,12 @@ def withColumn(col_in, col_out): def select(*args): return spark.DataFrame() +class F: + @staticmethod + def lit(value): + return value + + # these will match def get_tensors(device: str) -> torch.Tensor: a = torch.ones(2, 1) @@ -75,7 +81,29 @@ def projection(df_in: spark.DataFrame) -> spark.DataFrame: return df.withColumn("col2a", spark.functions.col("col2").cast("date")) +def assign_multiple(df, df2): + df = df.select("column") + result_df = df.select("another_column") + final_df = result_df.withColumn("column2", F.lit("abc")) + return final_df + + # these will not +def assign_alternating(df, df2): + df = df.select("column") + df2 = df2.select("another_column") + df = df.withColumn("column2", F.lit("abc")) + return df, df2 + + +def invalid(df_in: spark.DataFrame, alternative_df: spark.DataFrame) -> spark.DataFrame: + df = ( + df_in.select(["col1", "col2"]) + .withColumnRenamed("col1", "col1a") + ) + return alternative_df.withColumn("col2a", spark.functions.col("col2").cast("date")) + + def no_match(): y = 10 y = transform(y) diff --git a/test/data/err_184.txt b/test/data/err_184.txt index 51e0489..7cd32b5 100644 --- a/test/data/err_184.txt +++ b/test/data/err_184.txt @@ -1,4 +1,7 @@ -test/data/err_184.py:53:5 [FURB184]: Assignment statements should be chained -test/data/err_184.py:54:5 [FURB184]: Assignment statements should be chained -test/data/err_184.py:61:5 [FURB184]: Assignment statements should be chained -test/data/err_184.py:64:5 [FURB184]: Assignment statements should be chained +test/data/err_184.py:59:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:60:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:67:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:70:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:81:5 [FURB184]: Return statement should be chained +test/data/err_184.py:86:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:87:5 [FURB184]: Assignment statement should be chained