From 3435f63fbc201fbaa93730b53391a3eb442ab77d Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Thu, 29 Aug 2024 15:04:41 +0200
Subject: [PATCH] enable rule for asyncio, add more details to rule
explanation. Extend tests to be more thorough with state management.
---
docs/rules.rst | 2 +-
flake8_async/visitors/visitors.py | 2 +-
tests/eval_files/async121.py | 43 ++++++++++++++---
tests/eval_files/async121_anyio.py | 2 +-
tests/eval_files/async121_asyncio.py | 69 ++++++++++++++++++++++++++++
5 files changed, 109 insertions(+), 9 deletions(-)
create mode 100644 tests/eval_files/async121_asyncio.py
diff --git a/docs/rules.rst b/docs/rules.rst
index 9db42c1..64e6156 100644
--- a/docs/rules.rst
+++ b/docs/rules.rst
@@ -84,7 +84,7 @@ _`ASYNC120` : await-in-except
This is currently not able to detect asyncio shields.
_`ASYNC121`: control-flow-in-taskgroup
- `return`, `continue`, and `break` inside a :ref:`taskgroup_nursery` can lead to counterintuitive behaviour. Refactor the code to instead cancel the :ref:`cancel_scope` inside the TaskGroup/Nursery and place the statement outside of the TaskGroup/Nursery block. See `Trio issue #1493 `.
+ `return`, `continue`, and `break` inside a :ref:`taskgroup_nursery` can lead to counterintuitive behaviour. Refactor the code to instead cancel the :ref:`cancel_scope` inside the TaskGroup/Nursery and place the statement outside of the TaskGroup/Nursery block. In asyncio a user might expect the statement to have an immediate effect, but it will wait for all tasks to finish before having an effect. See `Trio issue #1493 ` for further issues specific to trio/anyio.
Blocking sync calls in async functions
diff --git a/flake8_async/visitors/visitors.py b/flake8_async/visitors/visitors.py
index f5ef4c1..a553a0a 100644
--- a/flake8_async/visitors/visitors.py
+++ b/flake8_async/visitors/visitors.py
@@ -371,7 +371,7 @@ def visit_AsyncWith(self, node: ast.AsyncWith):
self.unsafe_stack.append("nursery")
elif get_matching_call(
item.context_expr, "create_task_group", base="anyio"
- ):
+ ) or get_matching_call(item.context_expr, "TaskGroup", base="asyncio"):
self.unsafe_stack.append("task group")
def visit_While(self, node: ast.While | ast.For):
diff --git a/tests/eval_files/async121.py b/tests/eval_files/async121.py
index 5986c8d..0beb851 100644
--- a/tests/eval_files/async121.py
+++ b/tests/eval_files/async121.py
@@ -1,12 +1,21 @@
-# ASYNCIO_NO_ERROR # not a problem in asyncio
+# ASYNCIO_NO_ERROR # checked in async121_asyncio.py
# ANYIO_NO_ERROR # checked in async121_anyio.py
import trio
+def condition() -> bool:
+ return False
+
+
async def foo_return():
async with trio.open_nursery():
- return # ASYNC121: 8, "return", "nursery"
+ if condition():
+ return # ASYNC121: 12, "return", "nursery"
+ while condition():
+ return # ASYNC121: 12, "return", "nursery"
+
+ return # safe
async def foo_return_nested():
@@ -26,7 +35,9 @@ async def foo_while_continue_safe():
async def foo_while_continue_unsafe():
while True:
async with trio.open_nursery():
- continue # ASYNC121: 12, "continue", "nursery"
+ if condition():
+ continue # ASYNC121: 16, "continue", "nursery"
+ continue # safe
async def foo_for_continue_safe():
@@ -38,7 +49,9 @@ async def foo_for_continue_safe():
async def foo_for_continue_unsafe():
for _ in range(5):
async with trio.open_nursery():
- continue # ASYNC121: 12, "continue", "nursery"
+ if condition():
+ continue # ASYNC121: 16, "continue", "nursery"
+ continue # safe
# break
@@ -51,7 +64,9 @@ async def foo_while_break_safe():
async def foo_while_break_unsafe():
while True:
async with trio.open_nursery():
- break # ASYNC121: 12, "break", "nursery"
+ if condition():
+ break # ASYNC121: 16, "break", "nursery"
+ continue # safe
async def foo_for_break_safe():
@@ -63,4 +78,20 @@ async def foo_for_break_safe():
async def foo_for_break_unsafe():
for _ in range(5):
async with trio.open_nursery():
- break # ASYNC121: 12, "break", "nursery"
+ if condition():
+ break # ASYNC121: 16, "break", "nursery"
+ continue # safe
+
+
+# nested nursery
+async def foo_nested_nursery():
+ async with trio.open_nursery():
+ if condition():
+ return # ASYNC121: 12, "return", "nursery"
+ async with trio.open_nursery():
+ if condition():
+ return # ASYNC121: 16, "return", "nursery"
+ if condition():
+ return # ASYNC121: 12, "return", "nursery"
+ if condition():
+ return # safe
diff --git a/tests/eval_files/async121_anyio.py b/tests/eval_files/async121_anyio.py
index 01b5244..2505e23 100644
--- a/tests/eval_files/async121_anyio.py
+++ b/tests/eval_files/async121_anyio.py
@@ -1,4 +1,4 @@
-# ASYNCIO_NO_ERROR # not a problem in asyncio
+# ASYNCIO_NO_ERROR # checked in async121_asyncio.py
# TRIO_NO_ERROR # checked in async121.py
# BASE_LIBRARY anyio
diff --git a/tests/eval_files/async121_asyncio.py b/tests/eval_files/async121_asyncio.py
new file mode 100644
index 0000000..80b6126
--- /dev/null
+++ b/tests/eval_files/async121_asyncio.py
@@ -0,0 +1,69 @@
+# ANYIO_NO_ERROR
+# TRIO_NO_ERROR # checked in async121.py
+# BASE_LIBRARY asyncio
+# TaskGroup was added in 3.11, we run type checking with 3.9
+# mypy: disable-error-code=attr-defined
+
+import asyncio
+
+
+async def foo_return():
+ async with asyncio.TaskGroup():
+ return # ASYNC121: 8, "return", "task group"
+
+
+async def foo_return_nested():
+ async with asyncio.TaskGroup():
+
+ def bar():
+ return # safe
+
+
+# continue
+async def foo_while_continue_safe():
+ async with asyncio.TaskGroup():
+ while True:
+ continue # safe
+
+
+async def foo_while_continue_unsafe():
+ while True:
+ async with asyncio.TaskGroup():
+ continue # ASYNC121: 12, "continue", "task group"
+
+
+async def foo_for_continue_safe():
+ async with asyncio.TaskGroup():
+ for _ in range(5):
+ continue # safe
+
+
+async def foo_for_continue_unsafe():
+ for _ in range(5):
+ async with asyncio.TaskGroup():
+ continue # ASYNC121: 12, "continue", "task group"
+
+
+# break
+async def foo_while_break_safe():
+ async with asyncio.TaskGroup():
+ while True:
+ break # safe
+
+
+async def foo_while_break_unsafe():
+ while True:
+ async with asyncio.TaskGroup():
+ break # ASYNC121: 12, "break", "task group"
+
+
+async def foo_for_break_safe():
+ async with asyncio.TaskGroup():
+ for _ in range(5):
+ break # safe
+
+
+async def foo_for_break_unsafe():
+ for _ in range(5):
+ async with asyncio.TaskGroup():
+ break # ASYNC121: 12, "break", "task group"