Skip to content

Commit

Permalink
add async121 control-flow-in-taskgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Aug 27, 2024
1 parent 225f15a commit 2742ad1
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Changelog

*[CalVer, YY.month.patch](https://calver.org/)*

24.8.1
======
- Add :ref:`ASYNC121 <async121>` control-flow-in-taskgroup

24.8.1
======
- Add config option ``transform-async-generator-decorators``, to list decorators which
Expand Down
3 changes: 3 additions & 0 deletions docs/rules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ _`ASYNC120` : await-in-except
This will not trigger when :ref:`ASYNC102 <ASYNC102>` does, and if you don't care about losing non-cancelled exceptions you could disable this rule.
This is currently not able to detect asyncio shields.

_`ASYNC121`: control-flow-in-taskgroup
`return`, `continue`, and `break` inside a :ref:`taskgroup_nursery` can lead to counterintuitive behaviour. Refactor the code to instead cancel the :ref:`cancel_scope` and place the statement outside of the TaskGroup/Nursery block. See `trio#1493 <https://github.com/python-trio/trio/issues/1493>`.


Blocking sync calls in async functions
======================================
Expand Down
52 changes: 52 additions & 0 deletions flake8_async/visitors/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,58 @@ def visit_Yield(self, node: ast.Yield):
visit_Lambda = visit_AsyncFunctionDef


@error_class
class Visitor121(Flake8AsyncVisitor):
error_codes: Mapping[str, str] = {
"ASYNC121": (
"{0} in a {1} block behaves counterintuitively in several"
" situations. Refactor to have the {0} outside."
)
}

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.unsafe_stack: list[str] = []

def visit_AsyncWith(self, node: ast.AsyncWith):
self.save_state(node, "unsafe_stack", copy=True)

for item in node.items:
if get_matching_call(item.context_expr, "open_nursery", base="trio"):
self.unsafe_stack.append("nursery")
elif get_matching_call(
item.context_expr, "create_task_group", base="anyio"
):
self.unsafe_stack.append("task group")

def visit_While(self, node: ast.While | ast.For):
self.save_state(node, "unsafe_stack", copy=True)
self.unsafe_stack.append("loop")

visit_For = visit_While

def check_loop_flow(self, node: ast.Continue | ast.Break, statement: str) -> None:
# self.unsafe_stack should never be empty, but no reason not to avoid a crash
# for invalid code.
if self.unsafe_stack and self.unsafe_stack[-1] != "loop":
self.error(node, statement, self.unsafe_stack[-1])

def visit_Continue(self, node: ast.Continue) -> None:
self.check_loop_flow(node, "continue")

def visit_Break(self, node: ast.Break) -> None:
self.check_loop_flow(node, "break")

def visit_Return(self, node: ast.Return) -> None:
for unsafe_cm in "nursery", "task group":
if unsafe_cm in self.unsafe_stack:
self.error(node, "return", unsafe_cm)

def visit_FunctionDef(self, node: ast.FunctionDef):
self.save_state(node, "unsafe_stack", copy=True)
self.unsafe_stack = []


@error_class_cst
class Visitor300(Flake8AsyncVisitor_cst):
error_codes: Mapping[str, str] = {
Expand Down
66 changes: 66 additions & 0 deletions tests/eval_files/async121.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# ASYNCIO_NO_ERROR # not a problem in asyncio
# ANYIO_NO_ERROR # checked in async121_anyio.py

import trio


async def foo_return():
async with trio.open_nursery():
return # ASYNC121: 8, "return", "nursery"


async def foo_return_nested():
async with trio.open_nursery():

def bar():
return # safe


# continue
async def foo_while_continue_safe():
async with trio.open_nursery():
while True:
continue # safe


async def foo_while_continue_unsafe():
while True:
async with trio.open_nursery():
continue # ASYNC121: 12, "continue", "nursery"


async def foo_for_continue_safe():
async with trio.open_nursery():
for _ in range(5):
continue # safe


async def foo_for_continue_unsafe():
for _ in range(5):
async with trio.open_nursery():
continue # ASYNC121: 12, "continue", "nursery"


# break
async def foo_while_break_safe():
async with trio.open_nursery():
while True:
break # safe


async def foo_while_break_unsafe():
while True:
async with trio.open_nursery():
break # ASYNC121: 12, "break", "nursery"


async def foo_for_break_safe():
async with trio.open_nursery():
for _ in range(5):
break # safe


async def foo_for_break_unsafe():
for _ in range(5):
async with trio.open_nursery():
break # ASYNC121: 12, "break", "nursery"
67 changes: 67 additions & 0 deletions tests/eval_files/async121_anyio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# ASYNCIO_NO_ERROR # not a problem in asyncio
# TRIO_NO_ERROR # checked in async121.py
# BASE_LIBRARY anyio

import anyio


async def foo_return():
async with anyio.create_task_group():
return # ASYNC121: 8, "return", "task group"


async def foo_return_nested():
async with anyio.create_task_group():

def bar():
return # safe


# continue
async def foo_while_continue_safe():
async with anyio.create_task_group():
while True:
continue # safe


async def foo_while_continue_unsafe():
while True:
async with anyio.create_task_group():
continue # ASYNC121: 12, "continue", "task group"


async def foo_for_continue_safe():
async with anyio.create_task_group():
for _ in range(5):
continue # safe


async def foo_for_continue_unsafe():
for _ in range(5):
async with anyio.create_task_group():
continue # ASYNC121: 12, "continue", "task group"


# break
async def foo_while_break_safe():
async with anyio.create_task_group():
while True:
break # safe


async def foo_while_break_unsafe():
while True:
async with anyio.create_task_group():
break # ASYNC121: 12, "break", "task group"


async def foo_for_break_safe():
async with anyio.create_task_group():
for _ in range(5):
break # safe


async def foo_for_break_unsafe():
for _ in range(5):
async with anyio.create_task_group():
break # ASYNC121: 12, "break", "task group"
3 changes: 3 additions & 0 deletions tests/test_flake8_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,9 @@ def _parse_eval_file(
"ASYNC116",
"ASYNC117",
"ASYNC118",
# opening nurseries & taskgroups can only be done in async context, so ASYNC121
# doesn't check for it
"ASYNC121",
"ASYNC300",
"ASYNC912",
}
Expand Down

0 comments on commit 2742ad1

Please sign in to comment.