Skip to content

Commit

Permalink
Add Flow.visualize() (#10417)
Browse files Browse the repository at this point in the history
Co-authored-by: Serina Grill <[email protected]>
  • Loading branch information
jakekaplan and serinamarie authored Aug 17, 2023
1 parent aaf5d25 commit 7d2a910
Show file tree
Hide file tree
Showing 6 changed files with 668 additions and 1 deletion.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dateparser >= 1.1.1
docker >= 4.0
fastapi >= 0.93
fsspec >= 2022.5.0
graphviz >= 0.20.1
griffe >= 0.20.0
httpx[http2] >= 0.23, != 0.23.2
importlib_metadata >= 4.4; python_version < '3.10'
Expand Down
72 changes: 71 additions & 1 deletion src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from prefect.client.schemas.objects import Flow as FlowSchema
from prefect.client.schemas.objects import FlowRun
from prefect.context import PrefectObjectRegistry, registry_from_script

from prefect.exceptions import (
MissingFlowError,
ParameterTypeError,
Expand All @@ -48,11 +49,12 @@
from prefect.settings import (
PREFECT_FLOW_DEFAULT_RETRIES,
PREFECT_FLOW_DEFAULT_RETRY_DELAY_SECONDS,
PREFECT_UNIT_TEST_MODE,
)
from prefect.states import State
from prefect.task_runners import BaseTaskRunner, ConcurrentTaskRunner
from prefect.utilities.annotations import NotSet
from prefect.utilities.asyncutils import is_async_fn
from prefect.utilities.asyncutils import is_async_fn, sync_compatible
from prefect.utilities.callables import (
get_call_parameters,
parameter_schema,
Expand All @@ -62,7 +64,19 @@
from prefect.utilities.collections import listrepr
from prefect.utilities.hashing import file_hash
from prefect.utilities.importtools import import_object
from prefect.utilities.visualization import (
GraphvizExecutableNotFoundError,
GraphvizImportError,
TaskVizTracker,
build_task_dependencies,
visualize_task_dependencies,
FlowVisualizationError,
get_task_viz_tracker,
track_viz_task,
VisualizationUnsupportedError,
)

from prefect._internal.compatibility.experimental import experimental

T = TypeVar("T") # Generic type var for capturing the inner return type of async funcs
R = TypeVar("R") # The return type of the user's function
Expand Down Expand Up @@ -537,6 +551,12 @@ def __call__(

return_type = "state" if return_state else "result"

task_viz_tracker = get_task_viz_tracker()
if task_viz_tracker:
# this is a subflow, for now return a single task and do not go further
# we can add support for exploring subflows for tasks in the future.
return track_viz_task(self.isasync, self.name, parameters)

return enter_flow_run_engine_from_flow_call(
self,
parameters,
Expand Down Expand Up @@ -589,6 +609,56 @@ def _run(
return_type="state",
)

@sync_compatible
@experimental(feature="The visualize feature", group="visualize", stacklevel=1)
async def visualize(self, *args, **kwargs):
"""
Generates a graphviz object representing the current flow. In IPython notebooks,
it's rendered inline, otherwise in a new window as a PNG.
Raises:
- ImportError: If `graphviz` isn't installed.
- GraphvizExecutableNotFoundError: If the `dot` executable isn't found.
- FlowVisualizationError: If the flow can't be visualized for any other reason.
"""
if not PREFECT_UNIT_TEST_MODE:
warnings.warn(
"`flow.visualize()` will execute code inside of your flow that is not"
" decorated with `@task` or `@flow`."
)

try:
with TaskVizTracker() as tracker:
if self.isasync:
await self.fn(*args, **kwargs)
else:
self.fn(*args, **kwargs)

graph = build_task_dependencies(tracker)

visualize_task_dependencies(graph, self.name)

except GraphvizImportError:
raise
except GraphvizExecutableNotFoundError:
raise
except VisualizationUnsupportedError:
raise
except FlowVisualizationError:
raise
except Exception as e:
msg = (
"It's possible you are trying to visualize a flow that contains "
"code that directly interacts with the result of a task"
" inside of the flow. \nTry passing a `viz_return_value` "
"to the task decorator, e.g. `@task(viz_return_value=[1, 2, 3]).`"
)

new_exception = type(e)(str(e) + "\n" + msg)
# Copy traceback information from the original exception
new_exception.__traceback__ = e.__traceback__
raise new_exception


@overload
def flow(__fn: Callable[P, R]) -> Flow[P, R]:
Expand Down
5 changes: 5 additions & 0 deletions src/prefect/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,11 @@ def default_cloud_ui_url(settings, value):
Whether or not to warn when experimental Prefect workers are used.
"""

PREFECT_EXPERIMENTAL_WARN_VISUALIZE = Setting(bool, default=True)
"""
Whether or not to warn when experimental Prefect visualize is used.
"""

PREFECT_WORKER_HEARTBEAT_SECONDS = Setting(float, default=30)
"""
Number of seconds a worker should wait between sending a heartbeat.
Expand Down
35 changes: 35 additions & 0 deletions src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@
)
from prefect.utilities.hashing import hash_objects
from prefect.utilities.importtools import to_qualified_name
from prefect.utilities.visualization import (
get_task_viz_tracker,
track_viz_task,
VisualizationUnsupportedError,
)

if TYPE_CHECKING:
from prefect.context import TaskRunContext
Expand Down Expand Up @@ -168,6 +173,7 @@ class Task(Generic[P, R]):
execution with matching cache key is used.
on_failure: An optional list of callables to run when the task enters a failed state.
on_completion: An optional list of callables to run when the task enters a completed state.
viz_return_value: An optional value to return when the task dependency tree is visualized.
"""

# NOTE: These parameters (types, defaults, and docstrings) should be duplicated
Expand Down Expand Up @@ -204,6 +210,7 @@ def __init__(
refresh_cache: Optional[bool] = None,
on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
on_failure: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
viz_return_value: Optional[Any] = None,
):
# Validate if hook passed is list and contains callables
hook_categories = [on_completion, on_failure]
Expand Down Expand Up @@ -331,6 +338,7 @@ def __init__(
)
self.on_completion = on_completion
self.on_failure = on_failure
self.viz_return_value = viz_return_value

def with_options(
self,
Expand Down Expand Up @@ -361,6 +369,7 @@ def with_options(
refresh_cache: Optional[bool] = NotSet,
on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
on_failure: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
viz_return_value: Optional[Any] = None,
):
"""
Create a new task from the current object, updating provided options.
Expand Down Expand Up @@ -395,6 +404,7 @@ def with_options(
refresh_cache: A new option for enabling or disabling cache refresh.
on_completion: A new list of callables to run when the task enters a completed state.
on_failure: A new list of callables to run when the task enters a failed state.
viz_return_value: An optional value to return when the task dependency tree is visualized.
Returns:
A new `Task` instance.
Expand Down Expand Up @@ -482,6 +492,7 @@ def with_options(
),
on_completion=on_completion or self.on_completion,
on_failure=on_failure or self.on_failure,
viz_return_value=viz_return_value or self.viz_return_value,
)

@overload
Expand Down Expand Up @@ -530,6 +541,12 @@ def __call__(

return_type = "state" if return_state else "result"

task_run_tracker = get_task_viz_tracker()
if task_run_tracker:
return track_viz_task(
self.isasync, self.name, parameters, self.viz_return_value
)

return enter_task_run_engine(
self,
parameters=parameters,
Expand Down Expand Up @@ -727,6 +744,12 @@ def submit(
parameters = get_call_parameters(self.fn, args, kwargs)
return_type = "state" if return_state else "future"

task_viz_tracker = get_task_viz_tracker()
if task_viz_tracker:
raise VisualizationUnsupportedError(
"`task.submit()` is not currently supported by `flow.visualize()`"
)

return enter_task_run_engine(
self,
parameters=parameters,
Expand Down Expand Up @@ -898,6 +921,12 @@ def map(
parameters = get_call_parameters(self.fn, args, kwargs, apply_defaults=False)
return_type = "state" if return_state else "future"

task_viz_tracker = get_task_viz_tracker()
if task_viz_tracker:
raise VisualizationUnsupportedError(
"`task.map()` is not currently supported by `flow.visualize()`"
)

return enter_task_run_engine(
self,
parameters=parameters,
Expand Down Expand Up @@ -941,6 +970,7 @@ def task(
refresh_cache: Optional[bool] = None,
on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
on_failure: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
viz_return_value: Any = None,
) -> Callable[[Callable[P, R]], Task[P, R]]:
...

Expand Down Expand Up @@ -973,6 +1003,7 @@ def task(
refresh_cache: Optional[bool] = None,
on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
on_failure: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
viz_return_value: Any = None,
):
"""
Decorator to designate a function as a task in a Prefect workflow.
Expand Down Expand Up @@ -1028,6 +1059,7 @@ def task(
execution with matching cache key is used.
on_failure: An optional list of callables to run when the task enters a failed state.
on_completion: An optional list of callables to run when the task enters a completed state.
viz_return_value: An optional value to return when the task dependency tree is visualized.
Returns:
A callable `Task` object which, when called, will submit the task for execution.
Expand Down Expand Up @@ -1077,6 +1109,7 @@ def task(
>>> def my_task():
>>> return "hello"
"""

if __fn:
return cast(
Task[P, R],
Expand All @@ -1102,6 +1135,7 @@ def task(
refresh_cache=refresh_cache,
on_completion=on_completion,
on_failure=on_failure,
viz_return_value=viz_return_value,
),
)
else:
Expand Down Expand Up @@ -1129,5 +1163,6 @@ def task(
refresh_cache=refresh_cache,
on_completion=on_completion,
on_failure=on_failure,
viz_return_value=viz_return_value,
),
)
Loading

0 comments on commit 7d2a910

Please sign in to comment.