From bb8e5f855185ef7ec1bf2070bbcf605a5f4dd60f Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 11 Sep 2023 13:14:09 +0200 Subject: [PATCH 01/11] initial implementation for fluid interface check --- refurb/checks/readability/fluid_interface.py | 107 +++++++++++++++++++ test/data/err_184.py | 82 ++++++++++++++ test/data/err_184.txt | 4 + 3 files changed, 193 insertions(+) create mode 100644 refurb/checks/readability/fluid_interface.py create mode 100644 test/data/err_184.py create mode 100644 test/data/err_184.txt diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py new file mode 100644 index 0000000..338a0b5 --- /dev/null +++ b/refurb/checks/readability/fluid_interface.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass + +from mypy.nodes import ( + Block, + Statement, + AssignmentStmt, + MypyFile, + CallExpr, + MemberExpr, + NameExpr, +) + +from refurb.checks.common import check_block_like +from refurb.error import Error + + +@dataclass +class ErrorInfo(Error): + r"""When an API has a Fluent Interface (the ability to chain multiple calls together), you should chain those calls + instead of repeatedly assigning and using the value. 
+ Sometimes a return statement can be written more succinctly: + + Bad: + + ```python + def get_tensors(device: str) -> torch.Tensor: + a = torch.ones(2, 1) + a = a.long() + a = a.to(device) + return a + + + def process(file_name: str): + common_columns = ["col1_renamed", "col2_renamed", "custom_col"] + df = spark.read.parquet(file_name) + df = df \ + .withColumnRenamed('col1', 'col1_renamed') \ + .withColumnRenamed('col2', 'col2_renamed') + df = df \ + .select(common_columns) \ + .withColumn('service_type', F.lit('green')) + return df + ``` + + Good: + + ```python + def get_tensors(device: str) -> torch.Tensor: + a = ( + torch.ones(2, 1) + .long() + .to(device) + ) + return a + + def process(file_name: str): + common_columns = ["col1_renamed", "col2_renamed", "custom_col"] + df = ( + spark.read.parquet(file_name) + .withColumnRenamed('col1', 'col1_renamed') + .withColumnRenamed('col2', 'col2_renamed') + .select(common_columns) + .withColumn('service_type', F.lit('green')) + ) + return df + ``` + """ + + name = "use-fluid-interface" + code = 184 + categories = ("readability",) + + +def check(node: Block | MypyFile, errors: list[Error]) -> None: + check_block_like(check_stmts, node, errors) + + +def check_call(node) -> bool: + match node: + # Single chain + case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=y)): + return True + # Nested + case CallExpr(callee=MemberExpr(expr=call_node, name=y)): + return check_call(call_node) + + return False + + +def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: + last = "" + + for stmt in stmts: + match stmt: + case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue): + if last and f"{last}'" == name and check_call(rvalue): + errors.append( + ErrorInfo.from_node( + stmt, + f"Assignment statements should be chained", + ) + ) + + last = name + + case _: + last = "" diff --git a/test/data/err_184.py b/test/data/err_184.py new file mode 100644 index 0000000..825bc17 --- /dev/null +++ 
b/test/data/err_184.py @@ -0,0 +1,82 @@ +class torch: + @staticmethod + def ones(*args): + return torch + + @staticmethod + def long(): + return torch + + @staticmethod + def to(device: str): + return torch.Tensor() + + class Tensor: + pass + + +def transform(x): + return x + + +class spark: + class read: + @staticmethod + def parquet(file_name: str): + return spark.DataFrame() + + class functions: + @staticmethod + def lit(constant): + return constant + + @staticmethod + def col(col_name): + return col_name + + class DataFrame: + @staticmethod + def withColumnRenamed(col_in, col_out): + return spark.DataFrame() + + @staticmethod + def withColumn(col_in, col_out): + return spark.DataFrame() + + @staticmethod + def select(*args): + return spark.DataFrame() + +# these will match +def get_tensors(device: str) -> torch.Tensor: + a = torch.ones(2, 1) + a = a.long() + a = a.to(device) + return a + + +def process(file_name: str): + common_columns = ["col1_renamed", "col2_renamed", "custom_col"] + df = spark.read.parquet(file_name) + df = df \ + .withColumnRenamed('col1', 'col1_renamed') \ + .withColumnRenamed('col2', 'col2_renamed') + df = df \ + .select(common_columns) \ + .withColumn('service_type', spark.functions.lit('green')) + return df + + +def projection(df_in: spark.DataFrame) -> spark.DataFrame: + df = ( + df_in.select(["col1", "col2"]) + .withColumnRenamed("col1", "col1a") + ) + return df.withColumn("col2a", spark.functions.col("col2").cast("date")) + + +# these will not +def no_match(): + y = 10 + y = transform(y) + return y diff --git a/test/data/err_184.txt b/test/data/err_184.txt new file mode 100644 index 0000000..51e0489 --- /dev/null +++ b/test/data/err_184.txt @@ -0,0 +1,4 @@ +test/data/err_184.py:53:5 [FURB184]: Assignment statements should be chained +test/data/err_184.py:54:5 [FURB184]: Assignment statements should be chained +test/data/err_184.py:61:5 [FURB184]: Assignment statements should be chained +test/data/err_184.py:64:5 [FURB184]: Assignment 
statements should be chained From 77b6806c2b8785381bb6577bda76ac7b65628476 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Sat, 9 Dec 2023 22:34:11 +0100 Subject: [PATCH 02/11] support return statements and non-matching names --- refurb/checks/readability/fluid_interface.py | 79 ++++++++++++++++---- test/data/err_184.py | 35 +++++++++ test/data/err_184.txt | 4 - 3 files changed, 98 insertions(+), 20 deletions(-) delete mode 100644 test/data/err_184.txt diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py index 338a0b5..ac101b3 100644 --- a/refurb/checks/readability/fluid_interface.py +++ b/refurb/checks/readability/fluid_interface.py @@ -8,10 +8,12 @@ CallExpr, MemberExpr, NameExpr, + ReturnStmt, ) from refurb.checks.common import check_block_like from refurb.error import Error +from refurb.visitor import TraverserVisitor @dataclass @@ -24,11 +26,10 @@ class ErrorInfo(Error): ```python def get_tensors(device: str) -> torch.Tensor: - a = torch.ones(2, 1) - a = a.long() - a = a.to(device) - return a - + t1 = torch.ones(2, 1) + t2 = t1.long() + t3 = t2.to(device) + return t3 def process(file_name: str): common_columns = ["col1_renamed", "col2_renamed", "custom_col"] @@ -46,12 +47,12 @@ def process(file_name: str): ```python def get_tensors(device: str) -> torch.Tensor: - a = ( + t3 = ( torch.ones(2, 1) .long() .to(device) ) - return a + return t3 def process(file_name: str): common_columns = ["col1_renamed", "col2_renamed", "custom_col"] @@ -75,33 +76,79 @@ def check(node: Block | MypyFile, errors: list[Error]) -> None: check_block_like(check_stmts, node, errors) -def check_call(node) -> bool: +def check_call(node, name: str | None = None) -> bool: match node: # Single chain - case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=y)): - return True + case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=_)): + return name is None or name == x # Nested - case CallExpr(callee=MemberExpr(expr=call_node, 
name=y)): + case CallExpr(callee=MemberExpr(expr=call_node, name=_)): return check_call(call_node) return False +class NameReferenceVisitor(TraverserVisitor): + name: NameExpr + referenced: bool + + def __init__(self, name: NameExpr, stmt: Statement) -> None: + super().__init__() + self.name = name + self.stmt = stmt + self.referenced = False + + def visit_name_expr(self, node: NameExpr) -> None: + if not self.referenced and node.fullname == self.name.fullname: + self.referenced = True + + @property + def was_referenced(self) -> bool: + return self.referenced + + def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: last = "" + visitors = [] for stmt in stmts: + for visitor in visitors: + visitor.accept(stmt) + match stmt: case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue): - if last and f"{last}'" == name and check_call(rvalue): + if last and check_call(rvalue, name=last): + if f"{last}'" == name: + errors.append( + ErrorInfo.from_node( + stmt, + f"Assignment statement should be chained", + ) + ) + else: + # We need to ensure that the variable is not referenced somewhere else + name_expr = NameExpr(name=last) + name_expr.fullname = last + visitors.append(NameReferenceVisitor(name_expr, stmt)) + + last = name + case ReturnStmt(expr=rvalue): + if last and check_call(rvalue, name=last): errors.append( ErrorInfo.from_node( stmt, - f"Assignment statements should be chained", + f"Return statement should be chained", ) ) - - last = name - case _: last = "" + + # Ensure that variables are not referenced + for visitor in visitors: + if not visitor.referenced: + errors.append( + ErrorInfo.from_node( + visitor.stmt, + f"Assignment statement should be chained", + ) + ) diff --git a/test/data/err_184.py b/test/data/err_184.py index 825bc17..43034cd 100644 --- a/test/data/err_184.py +++ b/test/data/err_184.py @@ -47,6 +47,12 @@ def withColumn(col_in, col_out): def select(*args): return spark.DataFrame() +class F: + @staticmethod + def lit(value): 
+ return value + + # these will match def get_tensors(device: str) -> torch.Tensor: a = torch.ones(2, 1) @@ -75,7 +81,36 @@ def projection(df_in: spark.DataFrame) -> spark.DataFrame: return df.withColumn("col2a", spark.functions.col("col2").cast("date")) +def assign_multiple(df): + df = df.select("column") + result_df = df.select("another_column") + final_df = result_df.withColumn("column2", F.lit("abc")) + return final_df + + +# not yet supported +def assign_alternating(df, df2): + df = df.select("column") + df2 = df2.select("another_column") + df = df.withColumn("column2", F.lit("abc")) + return df, df2 + + # these will not +def assign_multiple_referenced(df, df2): + df = df.select("column") + result_df = df.select("another_column") + return df, result_df + + +def invalid(df_in: spark.DataFrame, alternative_df: spark.DataFrame) -> spark.DataFrame: + df = ( + df_in.select(["col1", "col2"]) + .withColumnRenamed("col1", "col1a") + ) + return alternative_df.withColumn("col2a", spark.functions.col("col2").cast("date")) + + def no_match(): y = 10 y = transform(y) diff --git a/test/data/err_184.txt b/test/data/err_184.txt deleted file mode 100644 index 51e0489..0000000 --- a/test/data/err_184.txt +++ /dev/null @@ -1,4 +0,0 @@ -test/data/err_184.py:53:5 [FURB184]: Assignment statements should be chained -test/data/err_184.py:54:5 [FURB184]: Assignment statements should be chained -test/data/err_184.py:61:5 [FURB184]: Assignment statements should be chained -test/data/err_184.py:64:5 [FURB184]: Assignment statements should be chained From 554629f66201b8db5635def4b8de6de028351b0d Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Sun, 10 Dec 2023 00:06:19 +0100 Subject: [PATCH 03/11] performance optimization --- refurb/checks/readability/fluid_interface.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py index ac101b3..a1ac8a1 100644 --- 
a/refurb/checks/readability/fluid_interface.py +++ b/refurb/checks/readability/fluid_interface.py @@ -114,6 +114,8 @@ def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: for stmt in stmts: for visitor in visitors: visitor.accept(stmt) + # No need to track referenced variables anymore + visitors = [visitor for visitor in visitors if not visitor.referenced] match stmt: case AssignmentStmt(lvalues=[NameExpr(name=name)], rvalue=rvalue): From 19012e85c1af05a781db261254f36806c47c658b Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Sun, 10 Dec 2023 00:40:37 +0100 Subject: [PATCH 04/11] test results --- test/data/err_184.txt | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 test/data/err_184.txt diff --git a/test/data/err_184.txt b/test/data/err_184.txt new file mode 100644 index 0000000..7cd32b5 --- /dev/null +++ b/test/data/err_184.txt @@ -0,0 +1,7 @@ +test/data/err_184.py:59:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:60:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:67:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:70:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:81:5 [FURB184]: Return statement should be chained +test/data/err_184.py:86:5 [FURB184]: Assignment statement should be chained +test/data/err_184.py:87:5 [FURB184]: Assignment statement should be chained From 68dda8aeb06f8e5b7316403f5c5224c26b77ff02 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Sun, 10 Dec 2023 01:13:28 +0100 Subject: [PATCH 05/11] exclude other references --- refurb/checks/readability/fluid_interface.py | 16 ++++++++++++---- test/data/err_184.py | 5 +++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py index a1ac8a1..f0896e5 100644 --- a/refurb/checks/readability/fluid_interface.py +++ b/refurb/checks/readability/fluid_interface.py @@ 
-11,7 +11,7 @@ ReturnStmt, ) -from refurb.checks.common import check_block_like +from refurb.checks.common import check_block_like, ReadCountVisitor from refurb.error import Error from refurb.visitor import TraverserVisitor @@ -80,9 +80,17 @@ def check_call(node, name: str | None = None) -> bool: match node: # Single chain case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=_)): - return name is None or name == x + if name is None or name == x: + # Exclude other references + x_expr = NameExpr(x) + x_expr.fullname = x + visitor = ReadCountVisitor(x_expr) + visitor.accept(node) + return visitor.read_count == 1 + return False + # Nested - case CallExpr(callee=MemberExpr(expr=call_node, name=_)): + case CallExpr(callee=MemberExpr(expr=call_node, name=y)): return check_call(call_node) return False @@ -92,7 +100,7 @@ class NameReferenceVisitor(TraverserVisitor): name: NameExpr referenced: bool - def __init__(self, name: NameExpr, stmt: Statement) -> None: + def __init__(self, name: NameExpr, stmt: Statement | None = None) -> None: super().__init__() self.name = name self.stmt = stmt diff --git a/test/data/err_184.py b/test/data/err_184.py index 43034cd..f9dc580 100644 --- a/test/data/err_184.py +++ b/test/data/err_184.py @@ -97,6 +97,11 @@ def assign_alternating(df, df2): # these will not +def _(x): + y = x.m() + return y.operation(*[v for v in y]) + + def assign_multiple_referenced(df, df2): df = df.select("column") result_df = df.select("another_column") From af3300641271f4b46ae6200076c6f3d2175dedce Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Mon, 11 Dec 2023 10:36:51 +0100 Subject: [PATCH 06/11] fix tests and exclude `_` --- refurb/checks/readability/fluid_interface.py | 4 ++-- test/data/err_184.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py index f0896e5..835fdaa 100644 --- a/refurb/checks/readability/fluid_interface.py +++ 
b/refurb/checks/readability/fluid_interface.py @@ -91,7 +91,7 @@ def check_call(node, name: str | None = None) -> bool: # Nested case CallExpr(callee=MemberExpr(expr=call_node, name=y)): - return check_call(call_node) + return check_call(call_node, name=name) return False @@ -141,7 +141,7 @@ def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: name_expr.fullname = last visitors.append(NameReferenceVisitor(name_expr, stmt)) - last = name + last = name if name != "_" else "" case ReturnStmt(expr=rvalue): if last and check_call(rvalue, name=last): errors.append( diff --git a/test/data/err_184.py b/test/data/err_184.py index f9dc580..58f37b3 100644 --- a/test/data/err_184.py +++ b/test/data/err_184.py @@ -97,6 +97,11 @@ def assign_alternating(df, df2): # these will not +def ignored(x): + _ = x.op1() + _ = _.op2() + return _ + def _(x): y = x.m() return y.operation(*[v for v in y]) From 97d71fceb378e71f97b2e23c309b48c5787cb0fc Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Tue, 19 Dec 2023 11:20:11 +0100 Subject: [PATCH 07/11] fix linting --- refurb/checks/readability/fluid_interface.py | 37 +++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py index 835fdaa..184ed92 100644 --- a/refurb/checks/readability/fluid_interface.py +++ b/refurb/checks/readability/fluid_interface.py @@ -1,25 +1,26 @@ from dataclasses import dataclass from mypy.nodes import ( - Block, - Statement, AssignmentStmt, - MypyFile, + Block, CallExpr, MemberExpr, + MypyFile, NameExpr, ReturnStmt, + Statement, ) -from refurb.checks.common import check_block_like, ReadCountVisitor +from refurb.checks.common import ReadCountVisitor, check_block_like from refurb.error import Error from refurb.visitor import TraverserVisitor @dataclass class ErrorInfo(Error): - r"""When an API has a Fluent Interface (the ability to chain multiple calls together), you should chain those 
calls - instead of repeatedly assigning and using the value. + r""" + When an API has a Fluent Interface (the ability to chain multiple calls together), you should + chain those calls instead of repeatedly assigning and using the value. Sometimes a return statement can be written more succinctly: Bad: @@ -76,7 +77,7 @@ def check(node: Block | MypyFile, errors: list[Error]) -> None: check_block_like(check_stmts, node, errors) -def check_call(node, name: str | None = None) -> bool: +def check_call(node: CallExpr, name: str | None = None) -> bool: match node: # Single chain case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=_)): @@ -90,7 +91,7 @@ def check_call(node, name: str | None = None) -> bool: return False # Nested - case CallExpr(callee=MemberExpr(expr=call_node, name=y)): + case CallExpr(callee=MemberExpr(expr=call_node, name=_)): return check_call(call_node, name=name) return False @@ -132,7 +133,7 @@ def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: errors.append( ErrorInfo.from_node( stmt, - f"Assignment statement should be chained", + "Assignment statement should be chained", ) ) else: @@ -147,18 +148,20 @@ def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: errors.append( ErrorInfo.from_node( stmt, - f"Return statement should be chained", + "Return statement should be chained", ) ) case _: last = "" # Ensure that variables are not referenced - for visitor in visitors: - if not visitor.referenced: - errors.append( - ErrorInfo.from_node( - visitor.stmt, - f"Assignment statement should be chained", - ) + errors.extend( + [ + ErrorInfo.from_node( + visitor.stmt, + "Assignment statement should be chained", ) + for visitor in visitors + if not visitor.referenced + ] + ) From d527c6b354fec60432141f1ba6bb161d42994dba Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Tue, 19 Dec 2023 11:23:03 +0100 Subject: [PATCH 08/11] fix mypy --- refurb/checks/readability/fluid_interface.py | 9 +++++---- 1 file changed, 5 
insertions(+), 4 deletions(-) diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py index 184ed92..7850442 100644 --- a/refurb/checks/readability/fluid_interface.py +++ b/refurb/checks/readability/fluid_interface.py @@ -4,6 +4,7 @@ AssignmentStmt, Block, CallExpr, + Expression, MemberExpr, MypyFile, NameExpr, @@ -77,7 +78,7 @@ def check(node: Block | MypyFile, errors: list[Error]) -> None: check_block_like(check_stmts, node, errors) -def check_call(node: CallExpr, name: str | None = None) -> bool: +def check_call(node: Expression, name: str | None = None) -> bool: match node: # Single chain case CallExpr(callee=MemberExpr(expr=NameExpr(name=x), name=_)): @@ -118,7 +119,7 @@ def was_referenced(self) -> bool: def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: last = "" - visitors = [] + visitors: list[NameReferenceVisitor] = [] for stmt in stmts: for visitor in visitors: @@ -144,7 +145,7 @@ def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: last = name if name != "_" else "" case ReturnStmt(expr=rvalue): - if last and check_call(rvalue, name=last): + if last and rvalue is not None and check_call(rvalue, name=last): errors.append( ErrorInfo.from_node( stmt, @@ -162,6 +163,6 @@ def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: "Assignment statement should be chained", ) for visitor in visitors - if not visitor.referenced + if not visitor.referenced and visitor.stmt is not None ] ) From 09ce1493a78fc4c9ddad16cf451af741d45fe8be Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Tue, 19 Dec 2023 11:25:01 +0100 Subject: [PATCH 09/11] add suggested test case --- test/data/err_184.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/data/err_184.py b/test/data/err_184.py index 58f37b3..fe8e495 100644 --- a/test/data/err_184.py +++ b/test/data/err_184.py @@ -125,3 +125,12 @@ def no_match(): y = 10 y = transform(y) return y + +def f(x): + if x: + name = 
"alice" + stripped = name.strip() + print(stripped) + else: + name = "bob" + print(name) From 93e0d46e89b873f1d9c31f48cba1127579e1ea54 Mon Sep 17 00:00:00 2001 From: Simon Brugman Date: Tue, 19 Dec 2023 12:08:23 +0100 Subject: [PATCH 10/11] restrict to top-level function definitions --- refurb/checks/readability/fluid_interface.py | 7 +++-- test/data/err_184.py | 27 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py index 7850442..f2d8f34 100644 --- a/refurb/checks/readability/fluid_interface.py +++ b/refurb/checks/readability/fluid_interface.py @@ -2,11 +2,10 @@ from mypy.nodes import ( AssignmentStmt, - Block, CallExpr, Expression, + FuncDef, MemberExpr, - MypyFile, NameExpr, ReturnStmt, Statement, @@ -74,8 +73,8 @@ def process(file_name: str): categories = ("readability",) -def check(node: Block | MypyFile, errors: list[Error]) -> None: - check_block_like(check_stmts, node, errors) +def check(node: FuncDef, errors: list[Error]) -> None: + check_block_like(check_stmts, node.body, errors) def check_call(node: Expression, name: str | None = None) -> bool: diff --git a/test/data/err_184.py b/test/data/err_184.py index fe8e495..f34cb92 100644 --- a/test/data/err_184.py +++ b/test/data/err_184.py @@ -134,3 +134,30 @@ def f(x): else: name = "bob" print(name) + +def g(x): + try: + name = "alice" + stripped = name.strip() + print(stripped) + except ValueError: + name = "bob" + print(name) + +def h(x): + for _ in (1, 2, 3): + name = "alice" + stripped = name.strip() + print(stripped) + else: + name = "bob" + print(name) + +def assign_multiple_try(df): + try: + df = df.select("column") + result_df = df.select("another_column") + final_df = result_df.withColumn("column2", F.lit("abc")) + return final_df + except ValueError: + return None From 227299e08675a738c77232675b3d45a71776868e Mon Sep 17 00:00:00 2001 From: dosisod 
<39638017+dosisod@users.noreply.github.com> Date: Thu, 21 Dec 2023 22:58:09 -0800 Subject: [PATCH 11/11] Remove `was_referenced` property, run `make docs` --- docs/checks.md | 52 ++++++++++++++++++++ refurb/checks/readability/fluid_interface.py | 4 -- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/docs/checks.md b/docs/checks.md index c3f4668..f364b1e 100644 --- a/docs/checks.md +++ b/docs/checks.md @@ -2180,4 +2180,56 @@ Good: ```python nums = [123, 456] num = str(num[0]) +``` + +## FURB184: `use-fluid-interface` + +Categories: `readability` + +When an API has a Fluent Interface (the ability to chain multiple calls together), you should +chain those calls instead of repeatedly assigning and using the value. +Sometimes a return statement can be written more succinctly: + +Bad: + +```python +def get_tensors(device: str) -> torch.Tensor: + t1 = torch.ones(2, 1) + t2 = t1.long() + t3 = t2.to(device) + return t3 + +def process(file_name: str): + common_columns = ["col1_renamed", "col2_renamed", "custom_col"] + df = spark.read.parquet(file_name) + df = df \ + .withColumnRenamed('col1', 'col1_renamed') \ + .withColumnRenamed('col2', 'col2_renamed') + df = df \ + .select(common_columns) \ + .withColumn('service_type', F.lit('green')) + return df +``` + +Good: + +```python +def get_tensors(device: str) -> torch.Tensor: + t3 = ( + torch.ones(2, 1) + .long() + .to(device) + ) + return t3 + +def process(file_name: str): + common_columns = ["col1_renamed", "col2_renamed", "custom_col"] + df = ( + spark.read.parquet(file_name) + .withColumnRenamed('col1', 'col1_renamed') + .withColumnRenamed('col2', 'col2_renamed') + .select(common_columns) + .withColumn('service_type', F.lit('green')) + ) + return df ``` \ No newline at end of file diff --git a/refurb/checks/readability/fluid_interface.py b/refurb/checks/readability/fluid_interface.py index f2d8f34..aa9a45a 100644 --- a/refurb/checks/readability/fluid_interface.py +++ 
b/refurb/checks/readability/fluid_interface.py @@ -111,10 +111,6 @@ def visit_name_expr(self, node: NameExpr) -> None: if not self.referenced and node.fullname == self.name.fullname: self.referenced = True - @property - def was_referenced(self) -> bool: - return self.referenced - def check_stmts(stmts: list[Statement], errors: list[Error]) -> None: last = ""