diff --git a/refurb/checks/common.py b/refurb/checks/common.py index 96257f9..39b8494 100644 --- a/refurb/checks/common.py +++ b/refurb/checks/common.py @@ -335,4 +335,9 @@ def _stringify(node: Node) -> str: return f"lambda{args}: {body}" + case ListExpr(items=items): + inner = ", ".join(stringify(x) for x in items) + + return f"[{inner}]" + raise ValueError diff --git a/refurb/checks/itertools/use_chain_from_iterable.py b/refurb/checks/itertools/use_chain_from_iterable.py index f03f268..aff47d7 100644 --- a/refurb/checks/itertools/use_chain_from_iterable.py +++ b/refurb/checks/itertools/use_chain_from_iterable.py @@ -8,8 +8,10 @@ ListExpr, NameExpr, RefExpr, + SetComprehension, ) +from refurb.checks.common import stringify from refurb.error import Error @@ -81,7 +83,7 @@ def is_flatten_generator(node: GeneratorExpr) -> bool: def check( - node: ListComprehension | GeneratorExpr | CallExpr, + node: ListComprehension | SetComprehension | GeneratorExpr | CallExpr, errors: list[Error], ) -> None: if id(node) in ignore: @@ -92,9 +94,11 @@ def check( old = "[... for ... in x for ... in ...]" new = "list(chain.from_iterable(x))" - msg = f"Replace `{old}` with `{new}`" + ignore.add(id(g)) - errors.append(ErrorInfo.from_node(node, msg)) + case SetComprehension(generator=g) if is_flatten_generator(g): + old = "{... for ... in x for ... in ...}" + new = "set(chain.from_iterable(x))" ignore.add(id(g)) @@ -102,31 +106,39 @@ def check( old = "... for ... in x for ... in ..." new = "chain.from_iterable(x)" - msg = f"Replace `{old}` with `{new}`" - - errors.append(ErrorInfo.from_node(node, msg)) - case CallExpr( callee=RefExpr(fullname="builtins.sum"), - args=[_, ListExpr(items=[])], + args=[arg, ListExpr(items=[])], ): - old = "sum(x, [])" - new = "chain.from_iterable(x)" + old = f"sum({stringify(arg)}, [])" + new = f"chain.from_iterable({stringify(arg)})" + + case CallExpr( + callee=RefExpr(fullname="functools.reduce"), + args=[op, arg] | [op, arg, ListExpr(items=[])], + ): + match op: + case RefExpr(fullname="_operator.add" | "_operator.concat"): + pass - msg = f"Replace `{old}` with `{new}`" + case _: + return - errors.append(ErrorInfo.from_node(node, msg)) + old = stringify(node) + new = f"chain.from_iterable({stringify(arg)})" case CallExpr( callee=RefExpr(fullname="itertools.chain") as callee, - args=[_], + args=[arg], arg_kinds=[ArgKind.ARG_STAR], ): chain = "chain" if isinstance(callee, NameExpr) else "itertools.chain" - old = f"{chain}(*x)" - new = f"{chain}.from_iterable(x)" + old = f"{chain}(*{stringify(arg)})" + new = f"{chain}.from_iterable({stringify(arg)})" - msg = f"Replace `{old}` with `{new}`" + case _: + return - errors.append(ErrorInfo.from_node(node, msg)) + msg = f"Replace `{old}` with `{new}`" + errors.append(ErrorInfo.from_node(node, msg)) diff --git a/test/data/err_179.py b/test/data/err_179.py index 72700e9..a1c1c44 100644 --- a/test/data/err_179.py +++ b/test/data/err_179.py @@ -1,5 +1,9 @@ +from functools import reduce +from operator import add, concat, iadd from itertools import chain +import functools import itertools +import operator rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] @@ -30,6 +34,21 @@ def flatten_via_chain_splat(rows): def flatten_via_chain_splat_2(rows): return itertools.chain(*rows) +def flatten_via_reduce_add(rows): + return reduce(add, rows) + +def flatten_via_reduce_add_with_default(rows): + return reduce(add, rows, []) + +def flatten_via_reduce_concat(rows): + return reduce(concat, rows) + +def flatten_via_reduce_concat_with_default(rows): + return reduce(concat, rows, []) + +def flatten_via_reduce_full_namespace(rows): + return functools.reduce(operator.add, rows) + # these should not @@ -68,3 +87,9 @@ def flatten_via_chain_without_splat(rows): def flatten_via_chain_from_iterable(rows): return chain.from_iterable(rows) + +def flatten_via_reduce_iadd(rows): + return reduce(iadd, rows, []) + +def flatten_via_reduce_non_empty_default(rows): + return reduce(add, rows, [1, 2, 3]) diff --git a/test/data/err_179.txt b/test/data/err_179.txt index ad67478..7665d46 100644 --- a/test/data/err_179.txt +++ b/test/data/err_179.txt @@ -1,7 +1,12 @@ -test/data/err_179.py:13:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` -test/data/err_179.py:16:12 [FURB179]: Replace `[... for ... in x for ... in ...]` with `list(chain.from_iterable(x))` -test/data/err_179.py:19:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` -test/data/err_179.py:22:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` -test/data/err_179.py:25:12 [FURB179]: Replace `sum(x, [])` with `chain.from_iterable(x)` -test/data/err_179.py:28:12 [FURB179]: Replace `chain(*x)` with `chain.from_iterable(x)` -test/data/err_179.py:31:12 [FURB179]: Replace `itertools.chain(*x)` with `itertools.chain.from_iterable(x)` +test/data/err_179.py:17:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` +test/data/err_179.py:20:12 [FURB179]: Replace `[... for ... in x for ... in ...]` with `list(chain.from_iterable(x))` +test/data/err_179.py:23:12 [FURB179]: Replace `{... for ... in x for ... in ...}` with `set(chain.from_iterable(x))` +test/data/err_179.py:26:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` +test/data/err_179.py:29:12 [FURB179]: Replace `sum(rows, [])` with `chain.from_iterable(rows)` +test/data/err_179.py:32:12 [FURB179]: Replace `chain(*rows)` with `chain.from_iterable(rows)` +test/data/err_179.py:35:12 [FURB179]: Replace `itertools.chain(*rows)` with `itertools.chain.from_iterable(rows)` +test/data/err_179.py:38:12 [FURB179]: Replace `reduce(add, rows)` with `chain.from_iterable(rows)` +test/data/err_179.py:41:12 [FURB179]: Replace `reduce(add, rows, [])` with `chain.from_iterable(rows)` +test/data/err_179.py:44:12 [FURB179]: Replace `reduce(concat, rows)` with `chain.from_iterable(rows)` +test/data/err_179.py:47:12 [FURB179]: Replace `reduce(concat, rows, [])` with `chain.from_iterable(rows)` +test/data/err_179.py:50:12 [FURB179]: Replace `functools.reduce(operator.add, rows)` with `chain.from_iterable(rows)`