Skip to content

Commit

Permalink
[FEAT]: mermaid formatter (#2619)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored Aug 7, 2024
1 parent 702ac73 commit 877efe2
Show file tree
Hide file tree
Showing 20 changed files with 793 additions and 150 deletions.
40 changes: 38 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[dependencies]
common-daft-config = {path = "src/common/daft-config", default-features = false}
common-display = {path = "src/common/display", default-features = false}
common-system-info = {path = "src/common/system-info", default-features = false}
common-tracing = {path = "src/common/tracing", default-features = false}
daft-compression = {path = "src/daft-compression", default-features = false}
Expand Down Expand Up @@ -47,7 +48,8 @@ python = [
"daft-table/python",
"daft-functions/python",
"common-daft-config/python",
"common-system-info/python"
"common-system-info/python",
"common-display/python"
]

[lib]
Expand Down Expand Up @@ -98,6 +100,7 @@ tikv-jemallocator = {version = "0.5.4", features = [
members = [
"src/arrow2",
"src/parquet2",
"src/common/display",
"src/common/error",
"src/common/io-config",
"src/common/treenode",
Expand Down Expand Up @@ -145,6 +148,7 @@ jaq-parse = "1.0.0"
jaq-std = "1.2.0"
num-derive = "0.3.3"
num-traits = "0.2"
pretty_assertions = "1.4.0"
rand = "^0.8"
rayon = "1.10.0"
regex = "1.10.4"
Expand Down
3 changes: 3 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterator

import pyarrow

from daft.dataframe.display import AsciiOptions, MermaidOptions
from daft.execution import physical_plan
from daft.io.scan import ScanOperator
from daft.plan_scheduler.physical_plan_scheduler import PartitionT
Expand Down Expand Up @@ -1554,6 +1555,7 @@ class PhysicalPlanScheduler:
) -> PhysicalPlanScheduler: ...
def num_partitions(self) -> int: ...
def repr_ascii(self, simple: bool) -> str: ...
def display_as(self, options: AsciiOptions | MermaidOptions) -> str: ...
def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.InProgressPhysicalPlan: ...
def run(self, psets: dict[str, list[PartitionT]]) -> Iterator[PyMicroPartition]: ...

Expand Down Expand Up @@ -1683,6 +1685,7 @@ class LogicalPlanBuilder:
def to_physical_plan_scheduler(self, cfg: PyDaftExecutionConfig) -> PhysicalPlanScheduler: ...
def to_adaptive_physical_plan_scheduler(self, cfg: PyDaftExecutionConfig) -> AdaptivePhysicalPlanScheduler: ...
def repr_ascii(self, simple: bool) -> str: ...
def display_as(self, options: AsciiOptions | MermaidOptions) -> str: ...

class NativeExecutor:
@staticmethod
Expand Down
30 changes: 26 additions & 4 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,41 +124,63 @@ def _result(self) -> Optional[PartitionSet]:
return self._result_cache.value

@DataframePublicAPI
def explain(self, show_all: bool = False, simple: bool = False, file: Optional[io.IOBase] = None) -> None:
def explain(
self, show_all: bool = False, format: str = "ascii", simple: bool = False, file: Optional[io.IOBase] = None
) -> Any:
"""Prints the (logical and physical) plans that will be executed to produce this DataFrame.
Defaults to showing the unoptimized logical plan. Use ``show_all=True`` to show the unoptimized logical plan,
the optimized logical plan, and the physical plan.
Args:
show_all (bool): Whether to show the optimized logical plan and the physical plan in addition to the
unoptimized logical plan.
format (str): The format to print the plan in. one of 'ascii' or 'mermaid'
simple (bool): Whether to only show the type of op for each node in the plan, rather than showing details
of how each op is configured.
file (Optional[io.IOBase]): Location to print the output to, or defaults to None which defaults to the default location for
print (in Python, that should be sys.stdout)
"""
is_cached = self._result_cache is not None
if format == "mermaid":
from daft.dataframe.display import MermaidFormatter
from daft.utils import in_notebook

instance = MermaidFormatter(self.__builder, show_all, simple, is_cached)
if file is not None:
# if we are printing to a file, we print the markdown representation of the plan
text = instance._repr_markdown_()
print(text, file=file)
if in_notebook():
# if in a notebook, we return the class instance and let jupyter display it
return instance
else:
# if we are not in a notebook, we return the raw markdown instead of the class instance
return repr(instance)

print_to_file = partial(print, file=file)

if self._result_cache is not None:
print_to_file("Result is cached and will skip computation\n")
print_to_file(self._builder.pretty_print(simple))
print_to_file(self._builder.pretty_print(simple, format=format))

print_to_file("However here is the logical plan used to produce this result:\n", file=file)

builder = self.__builder
print_to_file("== Unoptimized Logical Plan ==\n")
print_to_file(builder.pretty_print(simple))
print_to_file(builder.pretty_print(simple, format=format))
if show_all:
print_to_file("\n== Optimized Logical Plan ==\n")
builder = builder.optimize()
print_to_file(builder.pretty_print(simple))
print_to_file("\n== Physical Plan ==\n")
physical_plan_scheduler = builder.to_physical_plan_scheduler(get_context().daft_execution_config)
print_to_file(physical_plan_scheduler.pretty_print(simple))
print_to_file(physical_plan_scheduler.pretty_print(simple, format=format))
else:
print_to_file(
"\n \nSet `show_all=True` to also see the Optimized and Physical plans. This will run the query optimizer.",
)
return None

def num_partitions(self) -> int:
daft_execution_config = get_context().daft_execution_config
Expand Down
93 changes: 93 additions & 0 deletions daft/dataframe/display.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Optional, Union

from daft.context import get_context


class AsciiOptions:
simple: bool

def __init__(self, simple: bool = False):
self.simple = simple


class SubgraphOptions:
name: str
subgraph_id: str

def __init__(self, name: str, subgraph_id: str):
self.name = name
self.subgraph_id = subgraph_id


class MermaidOptions:
simple: bool
subgraph_options: Optional[SubgraphOptions]

def __init__(self, simple: bool = False, subgraph_options: Optional[SubgraphOptions] = None):
self.simple = simple
self.subgraph_options = subgraph_options

def with_subgraph_options(self, name: str, subgraph_id: str):
opts = MermaidOptions(self.simple, SubgraphOptions(name, subgraph_id))

return opts


def make_display_options(simple: bool, format: str) -> Union[MermaidOptions, AsciiOptions]:
if format == "ascii":
return AsciiOptions(simple)
elif format == "mermaid":
return MermaidOptions(simple)
else:
raise ValueError(f"Unknown format: {format}")


class MermaidFormatter:
def __init__(self, builder, show_all: bool = False, simple: bool = False, is_cached: bool = False):
self.builder = builder
self.show_all = show_all
self.simple = simple
self.is_cached = is_cached

def _repr_markdown_(self):
builder = self.builder
output = ""
display_opts = MermaidOptions(simple=self.simple)

# TODO handle cached plans

if self.show_all:
output = "```mermaid\n"
output += "flowchart TD\n"
output += builder._builder.display_as(
display_opts.with_subgraph_options(name="Unoptimized LogicalPlan", subgraph_id="unoptimized")
)
output += "\n"

builder = builder.optimize()
output += builder._builder.display_as(
display_opts.with_subgraph_options(name="Optimized LogicalPlan", subgraph_id="optimized")
)
output += "\n"
physical_plan_scheduler = builder.to_physical_plan_scheduler(get_context().daft_execution_config)
output += physical_plan_scheduler._scheduler.display_as(
display_opts.with_subgraph_options(name="Physical Plan", subgraph_id="physical")
)
output += "\n"
output += "unoptimized --> optimized\n"
output += "optimized --> physical\n"
output += "```\n"

else:
output = "```mermaid\n"
output += builder._builder.display_as(display_opts)
output += "\n"
output += "```\n"
output += (
"Set `show_all=True` to also see the Optimized and Physical plans. This will run the query optimizer."
)

return output

def __repr__(self):
return self._repr_markdown_()
12 changes: 6 additions & 6 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ScanOperatorHandle,
)
from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder
from daft.dataframe.display import make_display_options
from daft.expressions import Expression, col
from daft.logical.schema import Schema
from daft.runners.partitioning import PartitionCacheEntry
Expand Down Expand Up @@ -68,17 +69,16 @@ def schema(self) -> Schema:
pyschema = self._builder.schema()
return Schema._from_pyschema(pyschema)

def pretty_print(self, simple: bool = False) -> str:
def pretty_print(self, simple: bool = False, format: str = "ascii") -> str:
"""
Pretty prints the current underlying logical plan.
"""
if simple:
return self._builder.repr_ascii(simple=True)
else:
return repr(self)
display_opts = make_display_options(simple, format)
return self._builder.display_as(display_opts)

def __repr__(self) -> str:
return self._builder.repr_ascii(simple=False)
display_opts = make_display_options(simple=False, format="ascii")
return self._builder.display_as(display_opts)

def optimize(self) -> LogicalPlanBuilder:
"""
Expand Down
9 changes: 4 additions & 5 deletions daft/plan_scheduler/physical_plan_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from daft.daft import (
PyDaftExecutionConfig,
)
from daft.dataframe.display import make_display_options
from daft.execution import physical_plan
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import (
Expand All @@ -33,14 +34,12 @@ def from_logical_plan_builder(
def num_partitions(self) -> int:
return self._scheduler.num_partitions()

def pretty_print(self, simple: bool = False) -> str:
def pretty_print(self, simple: bool = False, format: str = "ascii") -> str:
"""
Pretty prints the current underlying physical plan.
"""
if simple:
return self._scheduler.repr_ascii(simple=True)
else:
return repr(self)
display_opts = make_display_options(simple, format)
return self._scheduler.display_as(display_opts)

def __repr__(self) -> str:
return self._scheduler.repr_ascii(simple=False)
Expand Down
14 changes: 14 additions & 0 deletions daft/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@
ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric())


def in_notebook():
"""Check if we are in a Jupyter notebook."""
try:
from IPython import get_ipython

if "IPKernelApp" not in get_ipython().config: # pragma: no cover
return False
except ImportError:
return False
except AttributeError:
return False
return True


def pydict_to_rows(pydict: dict[str, list]) -> list[frozenset[tuple[str, Any]]]:
"""Converts a dataframe pydict to a list of rows representation.
Expand Down
Loading

0 comments on commit 877efe2

Please sign in to comment.