Skip to content

Commit

Permalink
Replace str "output" by a dummy Op in the clients of the FunctionGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 30, 2024
1 parent 2143d85 commit a2e61fe
Show file tree
Hide file tree
Showing 18 changed files with 134 additions and 129 deletions.
19 changes: 10 additions & 9 deletions pytensor/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pytensor.graph.basic import Variable, io_toposort
from pytensor.graph.destroyhandler import DestroyHandler
from pytensor.graph.features import AlreadyThere, BadOptimization
from pytensor.graph.fg import Output
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.link.basic import Container, LocalLinker
Expand Down Expand Up @@ -628,7 +629,11 @@ def _is_used_in_graph(fgraph, var):
True if `var` is used by another node in the graph.
"""
return not (fgraph.clients[var] == [("output", 1)] or fgraph.clients[var] == [])
return any(
client
for client, _ in fgraph.clients[var]
if not isinstance(client.owner.op, Output)
)


def _check_strides_match(a, b, warn_err, op):
Expand Down Expand Up @@ -978,7 +983,7 @@ def _check_preallocated_output(
# disable memory checks in that mode, since they were already run.
try:
changed_inner_mode = False
if isinstance(getattr(node, "op", None), HasInnerGraph):
if isinstance(node.op, HasInnerGraph):
fn = node.op.fn
if not fn or not hasattr(fn, "maker") or not hasattr(fn.maker, "mode"):
_logger.warning(f"Expected pytensor function not found in {node.op}.fn")
Expand Down Expand Up @@ -1133,18 +1138,14 @@ class _FunctionGraphEvent:

def __init__(self, kind, node, idx=None, reason=None):
    # What happened to the graph: e.g. "import", "prune", or "change".
    self.kind = kind
    # Every client is now a real `Apply` node (graph outputs are
    # represented by dummy `Output` nodes rather than the string
    # "output"), so `node.op` can be read unconditionally.
    self.node = node
    self.op = node.op
    # Input index involved in the event, if any.
    self.idx = idx
    self.reason = str(reason)

def __str__(self):
if self.kind == "change":
if self.op != "output":
if not isinstance(self.op, Output):
msg = str(len(self.node.inputs))
else:
msg = ""
Expand Down
2 changes: 0 additions & 2 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ def view_tree_set(fgraph, v, treeset):
"""
treeset.add(v)
for cl, v_input_pos_to_cl in fgraph.clients[v]:
if cl == "output":
continue
vmap = cl.op.view_map
dmap = cl.op.destroy_map
for opos, iposlist in chain(vmap.items(), dmap.items()):
Expand Down
6 changes: 3 additions & 3 deletions pytensor/compile/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


if TYPE_CHECKING:
    from pytensor.graph.fg import FunctionGraph

# NOTE(review): `Output` is referenced at *runtime* below in the
# ``isinstance(c.op, Output)`` checks, so importing it only under
# TYPE_CHECKING would raise NameError when profiling runs.  Import it at
# module scope; if this ever creates an import cycle, move the import
# into the functions that use it instead.
from pytensor.graph.fg import Output


@contextmanager
Expand Down Expand Up @@ -1055,7 +1055,7 @@ def count_minimum_peak(node_list, fgraph, nodes_mem):
executable_nodes = set()
for var in fgraph.inputs:
for c, _ in fgraph.clients[var]:
if c != "output":
if not isinstance(c.op, Output):
deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps):
executable_nodes.add(c)
Expand Down Expand Up @@ -1183,7 +1183,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):

for var in node.outputs:
for c, _ in fgraph.clients[var]:
if c != "output":
if not isinstance(c.op, Output):
deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps):
new_exec_nodes.add(c)
Expand Down
7 changes: 3 additions & 4 deletions pytensor/graph/destroyhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.graph.features import AlreadyThere, Bookkeeper
from pytensor.graph.fg import Output
from pytensor.graph.utils import InconsistencyError
from pytensor.misc.ordered_set import OrderedSet

Expand Down Expand Up @@ -401,8 +402,6 @@ def has_destroyers(protected_list):
def recursive_destroys_finder(protected_var):
# protected_var is the idx'th input of app.
for app, idx in fgraph.clients[protected_var]:
if app == "output":
continue
destroy_maps = app.op.destroy_map.values()
# If True means that the apply node, destroys the protected_var.
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
Expand Down Expand Up @@ -575,10 +574,10 @@ def on_prune(self, fgraph, app, reason):

def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
"""
app.inputs[i] changed from old_r to new_r.
app.inputs[i] changed from old_r to new_r.
"""
if app == "output":
if isinstance(app.op, Output):
# A client whose op is a dummy `Output` node means the FunctionGraph is
# redefining which variables are being considered 'outputs' of the graph.
pass
Expand Down
125 changes: 64 additions & 61 deletions pytensor/graph/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from collections import OrderedDict
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, Literal, Union, cast
from typing import Any, Union, cast

import pytensor
from pytensor.configdefaults import config
Expand All @@ -19,15 +19,30 @@
)
from pytensor.graph.basic import as_string as graph_as_string
from pytensor.graph.features import AlreadyThere, Feature, ReplaceValidate
from pytensor.graph.op import Op
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
from pytensor.misc.ordered_set import OrderedSet


if TYPE_CHECKING:
from pytensor.graph.op import Op
ClientType = tuple[Apply, int]

ApplyOrOutput = Apply | Literal["output"]
ClientType = tuple[ApplyOrOutput, int]

class Output(Op):
    """Dummy `Op` marking a variable as an output of a `FunctionGraph`.

    One `Output` client node is stored per graph output; its `idx`
    records the position of that output within `FunctionGraph.outputs`.
    """

    __props__ = ("idx",)

    def __init__(self, idx):
        # Position of the represented output in `FunctionGraph.outputs`.
        self.idx = idx

    def make_node(self, inp):
        # An output marker consumes exactly one variable and produces
        # no outputs of its own.
        return Apply(self, [inp], [])

    def perform(self, node, inputs, outputs):
        # These nodes are pure bookkeeping; they must never be executed.
        raise RuntimeError("Output Ops should never be evaluated")

    def __str__(self):
        return f"output[{self.idx}]"


class FunctionGraph(MetaObject):
Expand Down Expand Up @@ -157,7 +172,7 @@ def add_output(
"""Add a new variable as an output to this `FunctionGraph`."""
self.outputs.append(var)
self.import_var(var, reason=reason, import_missing=import_missing)
self.clients[var].append(("output", len(self.outputs) - 1))
self.clients[var].append((Output(len(self.outputs) - 1).make_node(var), 0))

def add_input(self, var: Variable, check: bool = True) -> None:
"""Add a new variable as an input to this `FunctionGraph`.
Expand Down Expand Up @@ -198,10 +213,8 @@ def add_client(self, var: Variable, new_client: ClientType) -> None:
A ``(node, i)`` pair such that ``node.inputs[i]`` is `var`.
"""
if not isinstance(new_client[0], Apply) and new_client[0] != "output":
raise TypeError(
'The first entry of `new_client` must be an `Apply` node or the string `"output"`'
)
if not isinstance(new_client[0], Apply):
raise TypeError("The first entry of `new_client` must be an `Apply` node")
self.clients[var].append(new_client)

def remove_client(
Expand Down Expand Up @@ -382,7 +395,7 @@ def import_node(

def change_node_input(
self,
node: ApplyOrOutput,
node: Apply,
i: int,
new_var: Variable,
reason: str | None = None,
Expand All @@ -401,9 +414,7 @@ def change_node_input(
Parameters
----------
node
The node for which an input is to be changed. If the value is
the string ``"output"`` then the ``self.outputs`` will be used
instead of ``node.inputs``.
The node for which an input is to be changed.
i
The index in `node.inputs` that we want to change.
new_var
Expand All @@ -417,23 +428,16 @@ def change_node_input(
narrowed and would otherwise fail this check.
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == "output":
r = self.outputs[i]
if check and not r.type.is_super(new_var.type):
raise TypeError(
f"The type of the replacement ({new_var.type}) must be "
f"compatible with the type of the original Variable ({r.type})."
)
r = node.inputs[i]
if check and not r.type.is_super(new_var.type):
raise TypeError(
f"The type of the replacement ({new_var.type}) must be "
f"compatible with the type of the original Variable ({r.type})."
)
node.inputs[i] = new_var

if isinstance(node.op, Output):
self.outputs[i] = new_var
else:
assert isinstance(node, Apply)
r = node.inputs[i]
if check and not r.type.is_super(new_var.type):
raise TypeError(
f"The type of the replacement ({new_var.type}) must be "
f"compatible with the type of the original Variable ({r.type})."
)
node.inputs[i] = new_var

if r is new_var:
return
Expand Down Expand Up @@ -521,7 +525,7 @@ def replace_all(self, pairs: Iterable[tuple[Variable, Variable]], **kwargs) -> N
def _remove_output(self, idx: int):
"""Remove the output at index `idx` and update the indices in the clients entries.
`FunctionGraph.clients` contains entries like ``("output", i)`` under
`FunctionGraph.clients` contains entries like ``(output(var), i)`` under
each output variable in `FunctionGraph.outputs`. The ``i`` values
correspond to each output's location within the `FunctionGraph.outputs`
list, so, when an output is removed from the graph, all these entries
Expand All @@ -533,16 +537,23 @@ def _remove_output(self, idx: int):
which they're contained are already being updated in-place.
"""
old_idx_mappings = tuple((out, i) for i, out in enumerate(self.outputs))

self.outputs.pop(idx)

new_idx = 0
for out, old_idx in old_idx_mappings:
if old_idx == idx:
continue
out_clients = self.clients[out]
arrow: ClientType = ("output", old_idx)
arrow_idx = out_clients.index(arrow)
out_clients[arrow_idx] = ("output", new_idx)

if old_idx != new_idx:
out_clients = self.clients[out]
[client_out_idx] = [
i
for i, (out_client, _) in enumerate(out_clients)
if isinstance(out_client.op, Output)
and out_client.op.idx == old_idx
]
out_clients[client_out_idx] = (Output(new_idx).make_node(out), 0)
new_idx += 1

def remove_node(self, node: Apply, reason: str | None = None):
Expand Down Expand Up @@ -570,8 +581,8 @@ def remove_node(self, node: Apply, reason: str | None = None):
while out_clients:
out_client, out_idx = out_clients.pop()

if out_client == "output":
self._remove_output(out_idx)
if isinstance(out_client.op, Output):
self._remove_output(out_client.op.idx)

# TODO: We could short-circuit all of the graph walking and
# clear everything at once when all the outputs are gone.
Expand Down Expand Up @@ -630,32 +641,26 @@ def remove_node(self, node: Apply, reason: str | None = None):
self.execute_callbacks("on_prune", node, reason)

def remove_input(self, input_idx: int, reason: str | None = None):
    """Remove the input at index `input_idx`.

    Any node that depended on such input will also be removed.
    """
    removed_var = self.inputs.pop(input_idx)
    # Iterate over a snapshot: `remove_node` mutates `self.clients`
    # while we walk it.  The client index within each pair is unused.
    for client_node, _ in list(self.clients[removed_var]):
        self.remove_node(client_node, reason=reason)

def remove_output(self, output_idx: int, reason: str | None = None):
    """Remove the output at index `output_idx`.

    Parameters
    ----------
    output_idx
        Index in `FunctionGraph.outputs` of the output to remove.
    reason
        A description of why the output is being removed, forwarded to
        feature callbacks.
    """
    var = self.outputs[output_idx]
    # Locate the stale ``(Output node, 0)`` client entry *before*
    # `_remove_output` renumbers sibling entries; otherwise, when `var`
    # appears as several outputs, a sibling renumbered down to
    # `output_idx` could be matched instead of the entry that actually
    # belongs to the removed output.
    old_out_client = next(
        (client, i)
        for client, i in self.clients[var]
        if isinstance(client.op, Output) and client.op.idx == output_idx
    )
    self._remove_output(output_idx)
    self.remove_client(var, old_out_client, reason=reason, remove_if_empty=True)

def attach_feature(self, feature: Feature) -> None:
"""Add a ``graph.features.Feature`` to this function graph and trigger its ``on_attach`` callback."""
Expand Down Expand Up @@ -832,19 +837,17 @@ def check_integrity(self) -> None:
):
raise Exception(f"Undeclared input: {variable}")
for cl_node, i in self.clients[variable]:
if cl_node == "output":
if self.outputs[i] is not variable:
if isinstance(cl_node.op, Output):
out_idx = cl_node.op.idx
if self.outputs[out_idx] is not variable:
raise Exception(
f"Inconsistent clients list: {variable}, {self.outputs[i]}"
f"Inconsistent clients list: {variable}, {self.outputs[out_idx]}"
)
continue

assert isinstance(cl_node, Apply)

if cl_node not in nodes:
elif cl_node not in nodes:
raise Exception(
f"Client not in FunctionGraph: {variable}, {(cl_node, i)}"
)

if cl_node.inputs[i] is not variable:
raise Exception(
f"Inconsistent clients list: {variable}, {cl_node.inputs[i]}"
Expand Down
8 changes: 3 additions & 5 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
vars_between,
)
from pytensor.graph.features import AlreadyThere, Feature, NodeFinder
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
Expand Down Expand Up @@ -738,7 +738,7 @@ def apply(self, fgraph):
if any(
i in flatten(c.op.destroy_map.values())
for c, i in clients
if c != "output" and c.op.destroy_map
if c.op.destroy_map
):
continue

Expand Down Expand Up @@ -1626,8 +1626,6 @@ def transform(self, fgraph, node, get_nodes=True):

if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(fgraph, node):
if real_node == "output":
continue
ret = self.transform(fgraph, real_node, get_nodes=False)
if ret is not False and ret is not None:
return dict(zip(real_node.outputs, ret))
Expand Down Expand Up @@ -2407,7 +2405,7 @@ def importer(node):
if self.tracks_on_change_inputs:

def chin_(node, i, r, new_r, reason):
if node is not current_node and not isinstance(node, str):
if node is not current_node and not isinstance(node.op, Output):
q.append(node)

chin = chin_
Expand Down
Loading

0 comments on commit a2e61fe

Please sign in to comment.