[2.x engine] Append task run futures only when entering task run engine from flow run context (#14439)

Co-authored-by: nate nowack <[email protected]>
serinamarie and zzstoatzz authored Jul 9, 2024
1 parent f86bb85 commit 92e1187
Showing 3 changed files with 52 additions and 9 deletions.
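This commit threads an entering_from_task_run flag from Task.__call__ down through the 2.x engine so that a task called from inside another task run no longer has its future appended to flow_run_context.task_run_futures. A minimal sketch of the call shape this affects, with illustrative names that mirror the regression test added below:

    from prefect import flow, task

    @task
    def inner():
        ...

    @task(retries=1)
    def outer():
        # Called from inside a task run; with this change, the future created
        # for this call is not tracked as a flow-level task run future.
        inner()

    @flow
    def my_flow():
        outer()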
30 changes: 21 additions & 9 deletions src/prefect/engine.py
@@ -1375,6 +1375,7 @@ def enter_task_run_engine(
     return_type: EngineReturnType,
     task_runner: Optional[BaseTaskRunner],
     mapped: bool,
+    entering_from_task_run: Optional[bool] = False,
 ) -> Union[PrefectFuture, Awaitable[PrefectFuture], TaskRun]:
     """Sync entrypoint for task calls"""

@@ -1402,14 +1403,20 @@ def enter_task_run_engine(
     if flow_run_context.timeout_scope and flow_run_context.timeout_scope.cancel_called:
         raise TimeoutError("Flow run timed out")
 
+    call_arguments = {
+        "task": task,
+        "flow_run_context": flow_run_context,
+        "parameters": parameters,
+        "wait_for": wait_for,
+        "return_type": return_type,
+        "task_runner": task_runner,
+    }
+
+    if not mapped:
+        call_arguments["entering_from_task_run"] = entering_from_task_run
+
     begin_run = create_call(
-        begin_task_map if mapped else get_task_call_return_value,
-        task=task,
-        flow_run_context=flow_run_context,
-        parameters=parameters,
-        wait_for=wait_for,
-        return_type=return_type,
-        task_runner=task_runner,
+        begin_task_map if mapped else get_task_call_return_value, **call_arguments
     )
 
     if task.isasync and (
@@ -1536,6 +1543,7 @@ async def get_task_call_return_value(
     return_type: EngineReturnType,
     task_runner: Optional[BaseTaskRunner],
     extra_task_inputs: Optional[Dict[str, Set[TaskRunInput]]] = None,
+    entering_from_task_run: Optional[bool] = False,
 ):
     extra_task_inputs = extra_task_inputs or {}

@@ -1546,6 +1554,7 @@ async def get_task_call_return_value(
         wait_for=wait_for,
         task_runner=task_runner,
         extra_task_inputs=extra_task_inputs,
+        entering_from_task_run=entering_from_task_run,
     )
     if return_type == "future":
         return future
@@ -1564,12 +1573,14 @@ async def create_task_run_future(
     wait_for: Optional[Iterable[PrefectFuture]],
     task_runner: Optional[BaseTaskRunner],
     extra_task_inputs: Dict[str, Set[TaskRunInput]],
+    entering_from_task_run: Optional[bool] = False,
 ) -> PrefectFuture:
     # Default to the flow run's task runner
     task_runner = task_runner or flow_run_context.task_runner
 
     # Generate a name for the future
     dynamic_key = _dynamic_key_for_task_run(flow_run_context, task)
+
     task_run_name = (
         f"{task.name}-{dynamic_key}"
         if flow_run_context and flow_run_context.flow_run
@@ -1604,8 +1615,9 @@ async def create_task_run_future(
         )
     )
 
-    # Track the task run future in the flow run context
-    flow_run_context.task_run_futures.append(future)
+    if not entering_from_task_run:
+        # Track the task run future in the flow run context
+        flow_run_context.task_run_futures.append(future)
 
     if task_runner.concurrency_type == TaskConcurrencyType.SEQUENTIAL:
         await future._wait()
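Note on the enter_task_run_engine change above: the flag is added to call_arguments only on the non-mapped path, presumably because begin_task_map does not accept it, and the chosen entrypoint is then invoked with **call_arguments. A generic sketch of that conditional-kwargs dispatch pattern, using hypothetical names rather than Prefect APIs:

    def run_single(task, entering_from_task_run=False):
        return ("single", task, entering_from_task_run)

    def run_mapped(task):
        return ("mapped", task)

    def dispatch(task, mapped, entering_from_task_run=False):
        call_arguments = {"task": task}
        if not mapped:
            # Only the single-run entrypoint accepts the flag
            call_arguments["entering_from_task_run"] = entering_from_task_run
        fn = run_mapped if mapped else run_single
        return fn(**call_arguments)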
2 changes: 2 additions & 0 deletions src/prefect/tasks.py
@@ -685,6 +685,7 @@ def __call__(
                 return_type=return_type,
                 client=get_client(),
             )
+        entering_from_task_run = bool(TaskRunContext.get())
 
         return enter_task_run_engine(
             self,
@@ -693,6 +694,7 @@ def __call__(
             task_runner=SequentialTaskRunner(),
             return_type=return_type,
             mapped=False,
+            entering_from_task_run=entering_from_task_run,
         )
 
     @overload
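The entering_from_task_run = bool(TaskRunContext.get()) line added above is what detects nested task calls: TaskRunContext.get() returns the active task run context, or None when the call comes from flow-level code, so the flag is True only for task-in-task calls. A small illustrative sketch, assuming Prefect 2.x context behavior (demo names are hypothetical):

    from prefect import flow, task
    from prefect.context import TaskRunContext

    @task
    def probe():
        # Inside a task body there is an active task run context
        return TaskRunContext.get() is not None

    @flow
    def demo():
        assert TaskRunContext.get() is None  # flow-level code: no task run context
        assert probe() is True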
29 changes: 29 additions & 0 deletions tests/test_tasks.py
@@ -4226,6 +4226,35 @@ def my_flow():
         assert result == "Failed"
         assert count == 2
 
+    def test_nested_task_with_retries_on_outer_task(self):
+        """
+        Regression test for https://github.com/PrefectHQ/prefect/issues/14390
+        where the flow run would be marked as failed despite the tasks eventually succeeding.
+        """
+
+        failed = False
+
+        @task
+        def nested_flaky_task():
+            # This task will fail the first time it is run, but will succeed if called a second time
+            nonlocal failed
+            if not failed:
+                failed = True
+                raise ValueError("Forced task failure")
+
+        @task(
+            retries=1,
+        )
+        def top_task():
+            nested_flaky_task()
+
+        @flow
+        def nested_task_flow():
+            top_task()
+
+        result = nested_task_flow()
+        assert result[0].is_completed()
+
     def test_nested_task_with_retries_on_inner_and_outer_task(self):
         count = 0
 
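Beyond the assertion in the new test, another way to observe the fix (a hypothetical follow-up check, not part of the commit) is to ask for the flow's final state directly; before this change the failed first attempt of nested_flaky_task was tracked on the flow run context and rolled the flow up as Failed, whereas now it should complete:

    final_state = nested_task_flow(return_state=True)
    assert final_state.is_completed()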
