From 7d2a910330f67ccfccd0563f62d7df39e45162fb Mon Sep 17 00:00:00 2001 From: jakekaplan <40362401+jakekaplan@users.noreply.github.com> Date: Thu, 17 Aug 2023 15:41:41 -0400 Subject: [PATCH] Add `Flow.visualize()` (#10417) Co-authored-by: Serina Grill <42048900+serinamarie@users.noreply.github.com> --- requirements.txt | 1 + src/prefect/flows.py | 72 ++++- src/prefect/settings.py | 5 + src/prefect/tasks.py | 35 +++ src/prefect/utilities/visualization.py | 201 ++++++++++++++ tests/utilities/test_visualization.py | 355 +++++++++++++++++++++++++ 6 files changed, 668 insertions(+), 1 deletion(-) create mode 100644 src/prefect/utilities/visualization.py create mode 100644 tests/utilities/test_visualization.py diff --git a/requirements.txt b/requirements.txt index a81cd5b629e4..63483e5e5412 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ dateparser >= 1.1.1 docker >= 4.0 fastapi >= 0.93 fsspec >= 2022.5.0 +graphviz >= 0.20.1 griffe >= 0.20.0 httpx[http2] >= 0.23, != 0.23.2 importlib_metadata >= 4.4; python_version < '3.10' diff --git a/src/prefect/flows.py b/src/prefect/flows.py index b14229c851db..97b9a88652e6 100644 --- a/src/prefect/flows.py +++ b/src/prefect/flows.py @@ -37,6 +37,7 @@ from prefect.client.schemas.objects import Flow as FlowSchema from prefect.client.schemas.objects import FlowRun from prefect.context import PrefectObjectRegistry, registry_from_script + from prefect.exceptions import ( MissingFlowError, ParameterTypeError, @@ -48,11 +49,12 @@ from prefect.settings import ( PREFECT_FLOW_DEFAULT_RETRIES, PREFECT_FLOW_DEFAULT_RETRY_DELAY_SECONDS, + PREFECT_UNIT_TEST_MODE, ) from prefect.states import State from prefect.task_runners import BaseTaskRunner, ConcurrentTaskRunner from prefect.utilities.annotations import NotSet -from prefect.utilities.asyncutils import is_async_fn +from prefect.utilities.asyncutils import is_async_fn, sync_compatible from prefect.utilities.callables import ( get_call_parameters, parameter_schema, @@ -62,7 +64,19 @@ from prefect.utilities.collections import listrepr from prefect.utilities.hashing import file_hash from prefect.utilities.importtools import import_object +from prefect.utilities.visualization import ( + GraphvizExecutableNotFoundError, + GraphvizImportError, + TaskVizTracker, + build_task_dependencies, + visualize_task_dependencies, + FlowVisualizationError, + get_task_viz_tracker, + track_viz_task, + VisualizationUnsupportedError, +) +from prefect._internal.compatibility.experimental import experimental T = TypeVar("T") # Generic type var for capturing the inner return type of async funcs R = TypeVar("R") # The return type of the user's function @@ -537,6 +551,12 @@ def __call__( return_type = "state" if return_state else "result" + task_viz_tracker = get_task_viz_tracker() + if task_viz_tracker: + # this is a subflow, for now return a single task and do not go further + # we can add support for exploring subflows for tasks in the future. + return track_viz_task(self.isasync, self.name, parameters) + return enter_flow_run_engine_from_flow_call( self, parameters, @@ -589,6 +609,56 @@ def _run( return_type="state", ) + @sync_compatible + @experimental(feature="The visualize feature", group="visualize", stacklevel=1) + async def visualize(self, *args, **kwargs): + """ + Generates a graphviz object representing the current flow. In IPython notebooks, + it's rendered inline, otherwise in a new window as a PNG. + + Raises: + - ImportError: If `graphviz` isn't installed. + - GraphvizExecutableNotFoundError: If the `dot` executable isn't found. + - FlowVisualizationError: If the flow can't be visualized for any other reason. + """ + if not PREFECT_UNIT_TEST_MODE: + warnings.warn( + "`flow.visualize()` will execute code inside of your flow that is not" + " decorated with `@task` or `@flow`." + ) + + try: + with TaskVizTracker() as tracker: + if self.isasync: + await self.fn(*args, **kwargs) + else: + self.fn(*args, **kwargs) + + graph = build_task_dependencies(tracker) + + visualize_task_dependencies(graph, self.name) + + except GraphvizImportError: + raise + except GraphvizExecutableNotFoundError: + raise + except VisualizationUnsupportedError: + raise + except FlowVisualizationError: + raise + except Exception as e: + msg = ( + "It's possible you are trying to visualize a flow that contains " + "code that directly interacts with the result of a task" + " inside of the flow. \nTry passing a `viz_return_value` " + "to the task decorator, e.g. `@task(viz_return_value=[1, 2, 3]).`" + ) + + new_exception = type(e)(str(e) + "\n" + msg) + # Copy traceback information from the original exception + new_exception.__traceback__ = e.__traceback__ + raise new_exception + @overload def flow(__fn: Callable[P, R]) -> Flow[P, R]: diff --git a/src/prefect/settings.py b/src/prefect/settings.py index 25128e686ad2..f30537298a92 100644 --- a/src/prefect/settings.py +++ b/src/prefect/settings.py @@ -1205,6 +1205,11 @@ def default_cloud_ui_url(settings, value): Whether or not to warn when experimental Prefect workers are used. """ +PREFECT_EXPERIMENTAL_WARN_VISUALIZE = Setting(bool, default=True) +""" +Whether or not to warn when experimental Prefect visualize is used. +""" + PREFECT_WORKER_HEARTBEAT_SECONDS = Setting(float, default=30) """ Number of seconds a worker should wait between sending a heartbeat. diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 6dfd92b3a00a..1c274bc555ed 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -46,6 +46,11 @@ ) from prefect.utilities.hashing import hash_objects from prefect.utilities.importtools import to_qualified_name +from prefect.utilities.visualization import ( + get_task_viz_tracker, + track_viz_task, + VisualizationUnsupportedError, +) if TYPE_CHECKING: from prefect.context import TaskRunContext @@ -168,6 +173,7 @@ class Task(Generic[P, R]): execution with matching cache key is used. on_failure: An optional list of callables to run when the task enters a failed state. on_completion: An optional list of callables to run when the task enters a completed state. + viz_return_value: An optional value to return when the task dependency tree is visualized. """ # NOTE: These parameters (types, defaults, and docstrings) should be duplicated @@ -204,6 +210,7 @@ def __init__( refresh_cache: Optional[bool] = None, on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None, on_failure: Optional[List[Callable[["Task", TaskRun, State], None]]] = None, + viz_return_value: Optional[Any] = None, ): # Validate if hook passed is list and contains callables hook_categories = [on_completion, on_failure] @@ -331,6 +338,7 @@ def __init__( ) self.on_completion = on_completion self.on_failure = on_failure + self.viz_return_value = viz_return_value def with_options( self, @@ -361,6 +369,7 @@ def with_options( refresh_cache: Optional[bool] = NotSet, on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None, on_failure: Optional[List[Callable[["Task", TaskRun, State], None]]] = None, + viz_return_value: Optional[Any] = None, ): """ Create a new task from the current object, updating provided options. @@ -395,6 +404,7 @@ def with_options( refresh_cache: A new option for enabling or disabling cache refresh. on_completion: A new list of callables to run when the task enters a completed state. on_failure: A new list of callables to run when the task enters a failed state. + viz_return_value: An optional value to return when the task dependency tree is visualized. Returns: A new `Task` instance. @@ -482,6 +492,7 @@ def with_options( ), on_completion=on_completion or self.on_completion, on_failure=on_failure or self.on_failure, + viz_return_value=viz_return_value or self.viz_return_value, ) @overload @@ -530,6 +541,12 @@ def __call__( return_type = "state" if return_state else "result" + task_run_tracker = get_task_viz_tracker() + if task_run_tracker: + return track_viz_task( + self.isasync, self.name, parameters, self.viz_return_value + ) + return enter_task_run_engine( self, parameters=parameters, @@ -727,6 +744,12 @@ def submit( parameters = get_call_parameters(self.fn, args, kwargs) return_type = "state" if return_state else "future" + task_viz_tracker = get_task_viz_tracker() + if task_viz_tracker: + raise VisualizationUnsupportedError( + "`task.submit()` is not currently supported by `flow.visualize()`" + ) + return enter_task_run_engine( self, parameters=parameters, @@ -898,6 +921,12 @@ def map( parameters = get_call_parameters(self.fn, args, kwargs, apply_defaults=False) return_type = "state" if return_state else "future" + task_viz_tracker = get_task_viz_tracker() + if task_viz_tracker: + raise VisualizationUnsupportedError( + "`task.map()` is not currently supported by `flow.visualize()`" + ) + return enter_task_run_engine( self, parameters=parameters, @@ -941,6 +970,7 @@ def task( refresh_cache: Optional[bool] = None, on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None, on_failure: Optional[List[Callable[["Task", TaskRun, State], None]]] = None, + viz_return_value: Any = None, ) -> Callable[[Callable[P, R]], Task[P, R]]: ... @@ -973,6 +1003,7 @@ def task( refresh_cache: Optional[bool] = None, on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None, on_failure: Optional[List[Callable[["Task", TaskRun, State], None]]] = None, + viz_return_value: Any = None, ): """ Decorator to designate a function as a task in a Prefect workflow. @@ -1028,6 +1059,7 @@ def task( execution with matching cache key is used. on_failure: An optional list of callables to run when the task enters a failed state. on_completion: An optional list of callables to run when the task enters a completed state. + viz_return_value: An optional value to return when the task dependency tree is visualized. Returns: A callable `Task` object which, when called, will submit the task for execution. @@ -1077,6 +1109,7 @@ def task( >>> def my_task(): >>> return "hello" """ + if __fn: return cast( Task[P, R], @@ -1102,6 +1135,7 @@ def task( refresh_cache=refresh_cache, on_completion=on_completion, on_failure=on_failure, + viz_return_value=viz_return_value, ), ) else: @@ -1129,5 +1163,6 @@ def task( refresh_cache=refresh_cache, on_completion=on_completion, on_failure=on_failure, + viz_return_value=viz_return_value, ), ) diff --git a/src/prefect/utilities/visualization.py b/src/prefect/utilities/visualization.py new file mode 100644 index 000000000000..03d70c3ce8e6 --- /dev/null +++ b/src/prefect/utilities/visualization.py @@ -0,0 +1,201 @@ +""" +Utilities for working with Flow.visualize() +""" +from functools import partial +from typing import Any, List, Optional + +import graphviz + +from prefect._internal.concurrency.api import from_async + + +class FlowVisualizationError(Exception): + pass + + +class VisualizationUnsupportedError(Exception): + pass + + +class TaskVizTrackerState: + current = None + + +class GraphvizImportError(Exception): + pass + + +class GraphvizExecutableNotFoundError(Exception): + pass + + +def get_task_viz_tracker(): + return TaskVizTrackerState.current + + +def track_viz_task( + is_async: bool, + task_name: str, + parameters: dict, + viz_return_value: Optional[Any] = None, +): + """Return a result if sync otherwise return a coroutine that returns the result""" + if is_async: + return from_async.wait_for_call_in_loop_thread( + partial(_track_viz_task, task_name, parameters, viz_return_value) + ) + else: + return _track_viz_task(task_name, parameters, viz_return_value) + + +def _track_viz_task( + task_name, + parameters, + viz_return_value=None, +) -> Any: + task_run_tracker = get_task_viz_tracker() + if task_run_tracker: + upstream_tasks = [] + for k, v in parameters.items(): + if isinstance(v, VizTask): + upstream_tasks.append(v) + # if it's an object that we've already seen, + # we can use the object id to find if there is a trackable task + # if so, add it to the upstream tasks + elif id(v) in task_run_tracker.object_id_to_task: + upstream_tasks.append(task_run_tracker.object_id_to_task[id(v)]) + + viz_task = VizTask( + name=task_name, + upstream_tasks=upstream_tasks, + ) + task_run_tracker.add_task(viz_task) + + if viz_return_value: + task_run_tracker.link_viz_return_value_to_viz_task( + viz_return_value, viz_task + ) + return viz_return_value + + return viz_task + + +class VizTask: + def __init__( + self, + name: str, + upstream_tasks: Optional[List["VizTask"]] = None, + ): + self.name = name + self.upstream_tasks = upstream_tasks if upstream_tasks else [] + + +class TaskVizTracker: + def __init__(self): + self.tasks = [] + self.dynamic_task_counter = {} + self.object_id_to_task = {} + + def add_task(self, task: VizTask): + if task.name not in self.dynamic_task_counter: + self.dynamic_task_counter[task.name] = 0 + else: + self.dynamic_task_counter[task.name] += 1 + + task.name = f"{task.name}-{self.dynamic_task_counter[task.name]}" + self.tasks.append(task) + + def __enter__(self): + TaskVizTrackerState.current = self + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + TaskVizTrackerState.current = None + + def link_viz_return_value_to_viz_task( + self, viz_return_value: Any, viz_task: VizTask + ) -> None: + """ + We cannot track booleans, Ellipsis, None, NotImplemented, or the integers from -5 to 256 + because they are singletons. + """ + from prefect.engine import UNTRACKABLE_TYPES + + if (type(viz_return_value) in UNTRACKABLE_TYPES) or ( + isinstance(viz_return_value, int) and (-5 <= viz_return_value <= 256) + ): + return + self.object_id_to_task[id(viz_return_value)] = viz_task + + +def build_task_dependencies(task_run_tracker: TaskVizTracker): + """ + Constructs a Graphviz directed graph object that represents the dependencies + between tasks in the given TaskVizTracker. + + Parameters: + - task_run_tracker (TaskVizTracker): An object containing tasks and their + dependencies. + + Returns: + - graphviz.Digraph: A directed graph object depicting the relationships and + dependencies between tasks. + + Raises: + - GraphvizImportError: If there's an ImportError related to graphviz. + - FlowVisualizationError: If there's any other error during the visualization + process or if return values of tasks are directly accessed without + specifying a `viz_return_value`. + """ + try: + g = graphviz.Digraph() + for task in task_run_tracker.tasks: + g.node(task.name) + for upstream in task.upstream_tasks: + g.edge(upstream.name, task.name) + return g + except ImportError as exc: + raise GraphvizImportError from exc + except Exception: + raise FlowVisualizationError( + "Something went wrong building the flow's visualization." + " If you're interacting with the return value of a task" + " directly inside of your flow, you must set a set a `viz_return_value`" + ", for example `@task(viz_return_value=[1, 2, 3])`." + ) + + +def visualize_task_dependencies(graph: graphviz.Digraph, flow_run_name: str): + """ + Renders and displays a Graphviz directed graph representing task dependencies. + + The graph is rendered in PNG format and saved with the name specified by + flow_run_name. After rendering, the visualization is opened and displayed. + + Parameters: + - graph (graphviz.Digraph): The directed graph object to visualize. + - flow_run_name (str): The name to use when saving the rendered graph image. + + Raises: + - GraphvizExecutableNotFoundError: If Graphviz isn't found on the system. + - FlowVisualizationError: If there's any other error during the visualization + process or if return values of tasks are directly accessed without + specifying a `viz_return_value`. + """ + try: + graph.render(filename=flow_run_name, view=True, format="png", cleanup=True) + except graphviz.backend.ExecutableNotFound as exc: + msg = ( + "It appears you do not have Graphviz installed, or it is not on your " + "PATH. Please install Graphviz from http://www.graphviz.org/download/. " + "Note: Just installing the `graphviz` python package is not " + "sufficient." + ) + raise GraphvizExecutableNotFoundError(msg) from exc + except Exception: + raise FlowVisualizationError( + "Something went wrong building the flow's visualization." + " If you're interacting with the return value of a task" + " directly inside of your flow, you must set a set a `viz_return_value`" + ", for example `@task(viz_return_value=[1, 2, 3])`." + ) diff --git a/tests/utilities/test_visualization.py b/tests/utilities/test_visualization.py new file mode 100644 index 000000000000..0829006cad3d --- /dev/null +++ b/tests/utilities/test_visualization.py @@ -0,0 +1,355 @@ +import pytest + +from unittest.mock import Mock, MagicMock + +from prefect import flow, task +from prefect.utilities.visualization import ( + TaskVizTracker, + VizTask, + _track_viz_task, + get_task_viz_tracker, + VisualizationUnsupportedError, +) + +from prefect.settings import PREFECT_EXPERIMENTAL_WARN_VISUALIZE, temporary_settings + + +@pytest.fixture(autouse=True) +def disable_warn_visualize(): + """Disable the warning that is printed when a flow is visualized""" + with temporary_settings({PREFECT_EXPERIMENTAL_WARN_VISUALIZE: 0}): + yield + + +class TestTaskVizTracker: + async def test_get_task_run_tracker(self): + with TaskVizTracker() as tracker: + tracker_in_ctx = get_task_viz_tracker() + assert tracker_in_ctx + assert id(tracker) == id(tracker_in_ctx) + + async def test_get_task_run_tracker_outside_ctx(self): + tracker_outside_ctx = get_task_viz_tracker() + assert not tracker_outside_ctx + + with TaskVizTracker() as _: + pass + + tracker_outside_ctx = get_task_viz_tracker() + assert not tracker_outside_ctx + + async def test_add_task(self): + with TaskVizTracker() as tracker: + assert len(tracker.tasks) == 0 + + tracker.add_task(VizTask("my_task")) + assert len(tracker.tasks) == 1 + assert tracker.tasks[0].name == "my_task-0" + + tracker.add_task(VizTask("my_task")) + assert len(tracker.tasks) == 2 + assert tracker.tasks[1].name == "my_task-1" + + tracker.add_task(VizTask("my_other_task")) + assert len(tracker.tasks) == 3 + assert tracker.tasks[2].name == "my_other_task-0" + + @pytest.mark.parametrize( + "trackable", + [ + ("my_return_value", True), + ([1, 2, 3], True), + (500, True), + (None, False), + (1, False), + ], + ) + async def test_link_viz_return_value_to_viz_task(self, trackable): + value, is_trackable = trackable + with TaskVizTracker() as tracker: + trackable_task = VizTask("my_task") + tracker.link_viz_return_value_to_viz_task(value, trackable_task) + if is_trackable: + assert tracker.object_id_to_task[id(value)] == trackable_task + else: + assert id(value) not in tracker.object_id_to_task + + +class TestTrackTaskRun: + async def test_track_task_run_outside_ctx(self, monkeypatch): + mock = Mock() + monkeypatch.setattr( + "prefect.utilities.visualization.TaskVizTracker.add_task", mock + ) + _track_viz_task( + "my_task", + {"a": 1}, + ) + assert mock.call_count == 0 + + async def test_track_task_run_in_ctx(self, monkeypatch): + mock = Mock() + monkeypatch.setattr( + "prefect.utilities.visualization.TaskVizTracker.add_task", mock + ) + with TaskVizTracker(): + _track_viz_task( + "my_task", + {"a": 1}, + ) + assert mock.call_count == 1 + + async def test_track_task_run(self): + with TaskVizTracker() as tracker: + res = _track_viz_task("my_task", {"a": 1}) + assert isinstance(res, VizTask) + assert res.name == "my_task-0" + assert res.upstream_tasks == [] + + assert len(tracker.tasks) == 1 + assert res == tracker.tasks[0] + + async def test_track_task_run_with_upstream_task(self): + with TaskVizTracker() as tracker: + upstream_task = VizTask("upstream_task") + _track_viz_task("my_task", {"a": upstream_task}) + + assert len(tracker.tasks) == 1 + tracked_task = tracker.tasks[0] + assert tracked_task.name == "my_task-0" + assert len(tracked_task.upstream_tasks) == 1 + assert upstream_task in tracked_task.upstream_tasks + + async def test_track_task_run_returns_viz_return_value(self): + s = "my_return_value" + + with TaskVizTracker(): + res = _track_viz_task( + "upstream_task_with_value", {"a": 1}, viz_return_value=s + ) + assert res == s + assert id(res) == id(s) + + async def test_track_task_run_links_upstream_obj(self): + s = "my_return_value" + + with TaskVizTracker() as tracker: + _track_viz_task("upstream_task_with_value", {"a": 1}, viz_return_value=s) + + assert len(tracker.tasks) == 1 + assert len(tracker.object_id_to_task) == 1 + assert tracker.tasks[0].name == "upstream_task_with_value-0" + assert tracker.tasks[0].upstream_tasks == [] + + _track_viz_task("my_task", {"a": s}) + + assert len(tracker.tasks) == 2 + assert len(tracker.object_id_to_task) == 1 + assert tracker.tasks[1].name == "my_task-0" + assert tracker.tasks[1].upstream_tasks == [tracker.tasks[0]] + + +async def test_flow_visualize_doesnt_support_task_map(): + @task + def add_one(n): + return n + 1 + + @flow + def add_flow(): + add_one.map([1, 2, 3]) + + with pytest.raises(VisualizationUnsupportedError, match="task.map()"): + await add_flow.visualize() + + +async def test_flow_visualize_doesnt_support_task_submit(): + @task + def add_one(n): + return n + 1 + + @flow + def add_flow(): + add_one.submit(1) + + with pytest.raises(VisualizationUnsupportedError, match="task.submit()"): + await add_flow.visualize() + + +@task(viz_return_value=-10) +def sync_task_a(): + return "Sync Result A" + + +@task +def sync_task_b(input_data): + return f"Sync Result B from {input_data}" + + +@task +async def async_task_a(): + return "Async Result A" + + +@task +async def async_task_b(input_data): + return f"Async Result B from {input_data}" + + +@task(viz_return_value=5) +def untrackable_task_result(): + return "Untrackable Task Result" + + +@flow +def simple_sync_flow(): + a = sync_task_a() + sync_task_b(a) + + +@flow +async def flow_with_mixed_tasks(): + a = sync_task_a() + await async_task_b(a) + a = sync_task_a() + + +@flow +async def simple_async_flow_with_async_tasks(): + a = await async_task_a() + await async_task_b(a) + + +@flow +async def simple_async_flow_with_sync_tasks(): + a = sync_task_a() + sync_task_b(a) + + +@flow +async def async_flow_with_subflow(): + a = sync_task_a() + await simple_async_flow_with_sync_tasks() + sync_task_b(a) + + +@flow +def flow_with_task_interaction(): + a = sync_task_a() + b = a + 1 + sync_task_b(b) + + +@flow +def flow_with_flow_params(x=1): + a = sync_task_a() + b = a + x + sync_task_b(b) + + +@flow +def flow_with_untrackable_task_result(): + res = untrackable_task_result() + sync_task_b(res) + + +class TestFlowVisualise: + @pytest.mark.parametrize( + "test_flow", + [ + simple_sync_flow, + simple_async_flow_with_async_tasks, + simple_async_flow_with_sync_tasks, + async_flow_with_subflow, + flow_with_task_interaction, + flow_with_mixed_tasks, + flow_with_untrackable_task_result, + flow_with_flow_params, + ], + ) + def test_visualize_does_not_raise(self, test_flow, monkeypatch): + monkeypatch.setattr( + "prefect.flows.visualize_task_dependencies", MagicMock(return_value=None) + ) + + test_flow.visualize() + + @pytest.mark.parametrize( + "test_flow, expected_nodes", + [ + ( + simple_sync_flow, + { + '\t"sync_task_b-0"\n', + '\t"sync_task_a-0"\n', + '\t"sync_task_a-0" -> "sync_task_b-0"\n', + }, + ), + ( + simple_async_flow_with_async_tasks, + { + '\t"async_task_a-0"\n', + '\t"async_task_b-0"\n', + '\t"async_task_a-0" -> "async_task_b-0"\n', + }, + ), + ( + simple_async_flow_with_sync_tasks, + { + '\t"sync_task_a-0"\n', + '\t"sync_task_b-0"\n', + '\t"sync_task_a-0" -> "sync_task_b-0"\n', + }, + ), + ( + async_flow_with_subflow, + { + '\t"sync_task_a-0" -> "sync_task_b-0"\n', + '\t"sync_task_b-0"\n', + '\t"simple-async-flow-with-sync-tasks-0"\n', + '\t"sync_task_a-0"\n', + }, + ), + ( + flow_with_task_interaction, + { + '\t"sync_task_a-0"\n', + '\t"sync_task_b-0"\n', + }, + ), + ( + flow_with_mixed_tasks, + { + '\t"sync_task_a-0"\n', + '\t"async_task_b-0"\n', + '\t"sync_task_a-1"\n', + '\t"sync_task_a-0" -> "async_task_b-0"\n', + }, + ), + ( + flow_with_untrackable_task_result, + { + '\t"untrackable_task_result-0"\n', + '\t"sync_task_b-0"\n', + }, + ), + ( + flow_with_flow_params, + { + '\t"sync_task_a-0"\n', + '\t"sync_task_b-0"\n', + }, + ), + ], + ) + def test_visualize_graph_contents(self, test_flow, expected_nodes, monkeypatch): + mock_visualize = MagicMock(return_value=None) + monkeypatch.setattr("prefect.flows.visualize_task_dependencies", mock_visualize) + + test_flow.visualize() + graph = mock_visualize.call_args[0][0] + + actual_nodes = set(graph.body) + + assert ( + actual_nodes == expected_nodes + ), f"Expected nodes {expected_nodes} but found {actual_nodes}"