Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2.x engine] Append task run futures only when entering task run engine from flow run context #14439

Merged
merged 8 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,7 @@ def enter_task_run_engine(
return_type: EngineReturnType,
task_runner: Optional[BaseTaskRunner],
mapped: bool,
entering_from_task_run: Optional[bool] = False,
) -> Union[PrefectFuture, Awaitable[PrefectFuture], TaskRun]:
"""Sync entrypoint for task calls"""

Expand Down Expand Up @@ -1402,14 +1403,20 @@ def enter_task_run_engine(
if flow_run_context.timeout_scope and flow_run_context.timeout_scope.cancel_called:
raise TimeoutError("Flow run timed out")

call_arguments = {
"task": task,
"flow_run_context": flow_run_context,
"parameters": parameters,
"wait_for": wait_for,
"return_type": return_type,
"task_runner": task_runner,
}

if not mapped:
call_arguments["entering_from_task_run"] = entering_from_task_run

begin_run = create_call(
begin_task_map if mapped else get_task_call_return_value,
task=task,
flow_run_context=flow_run_context,
parameters=parameters,
wait_for=wait_for,
return_type=return_type,
task_runner=task_runner,
begin_task_map if mapped else get_task_call_return_value, **call_arguments
)

if task.isasync and (
Expand Down Expand Up @@ -1536,6 +1543,7 @@ async def get_task_call_return_value(
return_type: EngineReturnType,
task_runner: Optional[BaseTaskRunner],
extra_task_inputs: Optional[Dict[str, Set[TaskRunInput]]] = None,
entering_from_task_run: Optional[bool] = False,
):
extra_task_inputs = extra_task_inputs or {}

Expand All @@ -1546,6 +1554,7 @@ async def get_task_call_return_value(
wait_for=wait_for,
task_runner=task_runner,
extra_task_inputs=extra_task_inputs,
entering_from_task_run=entering_from_task_run,
)
if return_type == "future":
return future
Expand All @@ -1564,12 +1573,14 @@ async def create_task_run_future(
wait_for: Optional[Iterable[PrefectFuture]],
task_runner: Optional[BaseTaskRunner],
extra_task_inputs: Dict[str, Set[TaskRunInput]],
entering_from_task_run: Optional[bool] = False,
) -> PrefectFuture:
# Default to the flow run's task runner
task_runner = task_runner or flow_run_context.task_runner

# Generate a name for the future
dynamic_key = _dynamic_key_for_task_run(flow_run_context, task)

task_run_name = (
f"{task.name}-{dynamic_key}"
if flow_run_context and flow_run_context.flow_run
Expand Down Expand Up @@ -1604,8 +1615,9 @@ async def create_task_run_future(
)
)

# Track the task run future in the flow run context
flow_run_context.task_run_futures.append(future)
if not entering_from_task_run:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a Prefect dev environment running yet where I can test this locally, but the only thing I'd double-check is that the nested task run still gets its status correctly tracked in the UI and everything even when it's not appended to this list. Without knowing much about the Prefect internals, I don't have good intuition here... If you say it works, I believe you!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @thundercat1 - here's how it looks on this branch

image

using this code

from prefect import flow, task

# Module-level flag shared by every call to nested_flaky_task: it starts False
# and is set to True on the first call, so only that first call raises.
failed = False


@task(name="Nested Flaky Task")
def nested_flaky_task():
    # Fails exactly once: the module-level `failed` flag records that the
    # forced failure already happened, so every later call returns normally.
    global failed
    if failed:
        return
    failed = True
    raise ValueError("Forced task failure")


# Outer task: configured with one retry, so a failure raised by the nested
# task on the first attempt is followed by a second attempt.
@task(name="Top Task", retries=1)
def top_task():
    nested_flaky_task()


# Flow entry point for the reproduction: it only calls top_task, which in turn
# calls nested_flaky_task from inside a task run.
@flow
def nested_task_flow():
    top_task()


# Run the flow once when this file is executed directly as a script.
if __name__ == "__main__":
    nested_task_flow()

which matches my expectations.

# Track the task run future in the flow run context
flow_run_context.task_run_futures.append(future)

if task_runner.concurrency_type == TaskConcurrencyType.SEQUENTIAL:
await future._wait()
Expand Down
2 changes: 2 additions & 0 deletions src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ def __call__(
return_type=return_type,
client=get_client(),
)
entering_from_task_run = bool(TaskRunContext.get())

return enter_task_run_engine(
self,
Expand All @@ -693,6 +694,7 @@ def __call__(
task_runner=SequentialTaskRunner(),
return_type=return_type,
mapped=False,
entering_from_task_run=entering_from_task_run,
)

@overload
Expand Down
29 changes: 29 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4226,6 +4226,35 @@ def my_flow():
assert result == "Failed"
assert count == 2

def test_nested_task_with_retries_on_outer_task(self):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This unit test looks like what I'd have used to test, so while I didn't test it myself, if this works then I expect it will resolve the issue!

"""
Regression test for https://github.com/PrefectHQ/prefect/issues/14390
where the flow run would be marked as failed despite the tasks eventually succeeding.
"""

failed = False

@task
def nested_flaky_task():
# This task will fail the first time it is run, but will succeed if called a second time
nonlocal failed
if not failed:
failed = True
raise ValueError("Forced task failure")

@task(
retries=1,
)
def top_task():
nested_flaky_task()

@flow
def nested_task_flow():
top_task()

result = nested_task_flow()
assert result[0].is_completed()

def test_nested_task_with_retries_on_inner_and_outer_task(self):
count = 0

Expand Down