diff --git a/docs/changelog.rst b/docs/changelog.rst index daefaa6..cf75067 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,10 @@ Changelog *[CalVer, YY.month.patch](https://calver.org/)* +24.9.1 +====== +- Add :ref:`ASYNC121 ` control-flow-in-taskgroup + 24.8.1 ====== - Add config option ``transform-async-generator-decorators``, to list decorators which diff --git a/docs/rules.rst b/docs/rules.rst index 3eede30..64e6156 100644 --- a/docs/rules.rst +++ b/docs/rules.rst @@ -83,6 +83,9 @@ _`ASYNC120` : await-in-except This will not trigger when :ref:`ASYNC102 ` does, and if you don't care about losing non-cancelled exceptions you could disable this rule. This is currently not able to detect asyncio shields. +_`ASYNC121`: control-flow-in-taskgroup + `return`, `continue`, and `break` inside a :ref:`taskgroup_nursery` can lead to counterintuitive behaviour. Refactor the code to instead cancel the :ref:`cancel_scope` inside the TaskGroup/Nursery and place the statement outside of the TaskGroup/Nursery block. In asyncio a user might expect the statement to have an immediate effect, but it will wait for all tasks to finish before having an effect. See `Trio issue #1493 ` for further issues specific to trio/anyio. + Blocking sync calls in async functions ====================================== diff --git a/flake8_async/__init__.py b/flake8_async/__init__.py index 9e229e6..d4827f8 100644 --- a/flake8_async/__init__.py +++ b/flake8_async/__init__.py @@ -38,7 +38,7 @@ # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" -__version__ = "24.8.1" +__version__ = "24.9.1" # taken from https://github.com/Zac-HD/shed diff --git a/flake8_async/visitors/visitors.py b/flake8_async/visitors/visitors.py index d6b1363..6d9ca23 100644 --- a/flake8_async/visitors/visitors.py +++ b/flake8_async/visitors/visitors.py @@ -350,6 +350,59 @@ def visit_Yield(self, node: ast.Yield): visit_Lambda = visit_AsyncFunctionDef +@error_class +class Visitor121(Flake8AsyncVisitor): + error_codes: Mapping[str, str] = { + "ASYNC121": ( + "{0} in a {1} block behaves counterintuitively in several" + " situations. Refactor to have the {0} outside." + ) + } + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.unsafe_stack: list[str] = [] + + def visit_AsyncWith(self, node: ast.AsyncWith): + self.save_state(node, "unsafe_stack", copy=True) + + for item in node.items: + if get_matching_call(item.context_expr, "open_nursery", base="trio"): + self.unsafe_stack.append("nursery") + elif get_matching_call( + item.context_expr, "create_task_group", base="anyio" + ) or get_matching_call(item.context_expr, "TaskGroup", base="asyncio"): + self.unsafe_stack.append("task group") + + def visit_While(self, node: ast.While | ast.For | ast.AsyncFor): + self.save_state(node, "unsafe_stack", copy=True) + self.unsafe_stack.append("loop") + + visit_For = visit_While + visit_AsyncFor = visit_While + + def check_loop_flow(self, node: ast.Continue | ast.Break, statement: str) -> None: + # self.unsafe_stack should never be empty, but no reason not to avoid a crash + # for invalid code. + if self.unsafe_stack and self.unsafe_stack[-1] != "loop": + self.error(node, statement, self.unsafe_stack[-1]) + + def visit_Continue(self, node: ast.Continue) -> None: + self.check_loop_flow(node, "continue") + + def visit_Break(self, node: ast.Break) -> None: + self.check_loop_flow(node, "break") + + def visit_Return(self, node: ast.Return) -> None: + for unsafe_cm in "nursery", "task group": + if unsafe_cm in self.unsafe_stack: + self.error(node, "return", unsafe_cm) + + def visit_FunctionDef(self, node: ast.FunctionDef): + self.save_state(node, "unsafe_stack", copy=True) + self.unsafe_stack = [] + + @error_class_cst class Visitor300(Flake8AsyncVisitor_cst): error_codes: Mapping[str, str] = { diff --git a/tests/eval_files/async121.py b/tests/eval_files/async121.py new file mode 100644 index 0000000..78a6b6c --- /dev/null +++ b/tests/eval_files/async121.py @@ -0,0 +1,105 @@ +# ASYNCIO_NO_ERROR # checked in async121_asyncio.py +# ANYIO_NO_ERROR # checked in async121_anyio.py + +import trio +from typing import Any + + +# To avoid mypy unreachable-statement we wrap control flow calls in if statements +# they should have zero effect on the visitor logic. +def condition() -> bool: + return False + + +def bar() -> Any: ... + + +async def foo_return(): + async with trio.open_nursery(): + if condition(): + return # ASYNC121: 12, "return", "nursery" + while condition(): + return # ASYNC121: 12, "return", "nursery" + + return # safe + + +async def foo_return_nested(): + async with trio.open_nursery(): + + def bar(): + return # safe + + +async def foo_while_safe(): + async with trio.open_nursery(): + while True: + if condition(): + break # safe + if condition(): + continue # safe + continue # safe + + +async def foo_while_unsafe(): + while True: + async with trio.open_nursery(): + if condition(): + continue # ASYNC121: 16, "continue", "nursery" + if condition(): + break # ASYNC121: 16, "break", "nursery" + if condition(): + continue # safe + break # safe + + +async def foo_for_safe(): + async with trio.open_nursery(): + for _ in range(5): + if condition(): + continue # safe + if condition(): + break # safe + + +async def foo_for_unsafe(): + for _ in range(5): + async with trio.open_nursery(): + if condition(): + continue # ASYNC121: 16, "continue", "nursery" + if condition(): + break # ASYNC121: 16, "break", "nursery" + continue # safe + + +async def foo_async_for_safe(): + async with trio.open_nursery(): + async for _ in bar(): + if condition(): + continue # safe + if condition(): + break # safe + + +async def foo_async_for_unsafe(): + async for _ in bar(): + async with trio.open_nursery(): + if condition(): + continue # ASYNC121: 16, "continue", "nursery" + if condition(): + break # ASYNC121: 16, "break", "nursery" + continue # safe + + +# nested nursery +async def foo_nested_nursery(): + async with trio.open_nursery(): + if condition(): + return # ASYNC121: 12, "return", "nursery" + async with trio.open_nursery(): + if condition(): + return # ASYNC121: 16, "return", "nursery" + if condition(): + return # ASYNC121: 12, "return", "nursery" + if condition(): + return # safe diff --git a/tests/eval_files/async121_anyio.py b/tests/eval_files/async121_anyio.py new file mode 100644 index 0000000..12be0ae --- /dev/null +++ b/tests/eval_files/async121_anyio.py @@ -0,0 +1,22 @@ +# ASYNCIO_NO_ERROR # checked in async121_asyncio.py +# TRIO_NO_ERROR # checked in async121.py +# BASE_LIBRARY anyio + +import anyio + + +# To avoid mypy unreachable-statement we wrap control flow calls in if statements +# they should have zero effect on the visitor logic. +def condition() -> bool: + return False + + +# only tests that asyncio.TaskGroup is detected, main tests in async121.py +async def foo_return(): + while True: + async with anyio.create_task_group(): + if condition(): + continue # ASYNC121: 16, "continue", "task group" + if condition(): + break # ASYNC121: 16, "break", "task group" + return # ASYNC121: 12, "return", "task group" diff --git a/tests/eval_files/async121_asyncio.py b/tests/eval_files/async121_asyncio.py new file mode 100644 index 0000000..e8b8e27 --- /dev/null +++ b/tests/eval_files/async121_asyncio.py @@ -0,0 +1,24 @@ +# ANYIO_NO_ERROR +# TRIO_NO_ERROR # checked in async121.py +# BASE_LIBRARY asyncio +# TaskGroup was added in 3.11, we run type checking with 3.9 +# mypy: disable-error-code=attr-defined + +import asyncio + + +# To avoid mypy unreachable-statement we wrap control flow calls in if statements +# they should have zero effect on the visitor logic. +def condition() -> bool: + return False + + +# only tests that asyncio.TaskGroup is detected, main tests in async121.py +async def foo_return(): + while True: + async with asyncio.TaskGroup(): + if condition(): + continue # ASYNC121: 16, "continue", "task group" + if condition(): + break # ASYNC121: 16, "break", "task group" + return # ASYNC121: 12, "return", "task group" diff --git a/tests/test_flake8_async.py b/tests/test_flake8_async.py index dc8b82d..8140750 100644 --- a/tests/test_flake8_async.py +++ b/tests/test_flake8_async.py @@ -478,6 +478,9 @@ def _parse_eval_file( "ASYNC116", "ASYNC117", "ASYNC118", + # opening nurseries & taskgroups can only be done in async context, so ASYNC121 + # doesn't check for it + "ASYNC121", "ASYNC300", "ASYNC912", }