From a2e61fede129e70aeff4978503c3662c10b2b389 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 30 May 2024 11:34:03 +0200 Subject: [PATCH] Replace str "output" by a dummy Op in the clients of the FunctionGraph --- pytensor/compile/debugmode.py | 19 ++-- pytensor/compile/function/types.py | 2 - pytensor/compile/profiling.py | 6 +- pytensor/graph/destroyhandler.py | 7 +- pytensor/graph/fg.py | 125 +++++++++++----------- pytensor/graph/rewriting/basic.py | 8 +- pytensor/graph/rewriting/utils.py | 4 +- pytensor/link/c/basic.py | 4 +- pytensor/link/vm.py | 3 +- pytensor/printing.py | 13 ++- pytensor/scan/rewriting.py | 8 +- pytensor/tensor/basic.py | 4 +- pytensor/tensor/random/rewriting/basic.py | 14 ++- pytensor/tensor/rewriting/elemwise.py | 7 +- pytensor/tensor/rewriting/linalg.py | 2 - pytensor/tensor/rewriting/math.py | 7 +- pytensor/tensor/rewriting/shape.py | 9 +- tests/graph/test_fg.py | 21 ++-- 18 files changed, 134 insertions(+), 129 deletions(-) diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index f86288a450..1f76d4dc81 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -30,6 +30,7 @@ from pytensor.graph.basic import Variable, io_toposort from pytensor.graph.destroyhandler import DestroyHandler from pytensor.graph.features import AlreadyThere, BadOptimization +from pytensor.graph.fg import Output from pytensor.graph.op import HasInnerGraph, Op from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.link.basic import Container, LocalLinker @@ -628,7 +629,11 @@ def _is_used_in_graph(fgraph, var): True if `var` is used by another node in the graph. 
""" - return not (fgraph.clients[var] == [("output", 1)] or fgraph.clients[var] == []) + return any( + client + for client, _ in fgraph.clients[var] + if not isinstance(client.op, Output) + ) def _check_strides_match(a, b, warn_err, op): @@ -978,7 +983,7 @@ def _check_preallocated_output( # disable memory checks in that mode, since they were already run. try: changed_inner_mode = False - if isinstance(getattr(node, "op", None), HasInnerGraph): + if isinstance(node.op, HasInnerGraph): fn = node.op.fn if not fn or not hasattr(fn, "maker") or not hasattr(fn.maker, "mode"): _logger.warning(f"Expected pytensor function not found in {node.op}.fn") @@ -1133,18 +1138,14 @@ class _FunctionGraphEvent: def __init__(self, kind, node, idx=None, reason=None): self.kind = kind - if node == "output": - self.node = "output" - self.op = "output" - else: - self.node = node - self.op = node.op + self.node = node + self.op = node.op self.idx = idx self.reason = str(reason) def __str__(self): if self.kind == "change": - if self.op != "output": + if not isinstance(self.op, Output): msg = str(len(self.node.inputs)) else: msg = "" diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index c221d7cf41..71ed842b61 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -77,8 +77,6 @@ def view_tree_set(fgraph, v, treeset): """ treeset.add(v) for cl, v_input_pos_to_cl in fgraph.clients[v]: - if cl == "output": - continue vmap = cl.op.view_map dmap = cl.op.destroy_map for opos, iposlist in chain(vmap.items(), dmap.items()): diff --git a/pytensor/compile/profiling.py b/pytensor/compile/profiling.py index 56a88ecfe3..991a738b01 100644 --- a/pytensor/compile/profiling.py +++ b/pytensor/compile/profiling.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: - from pytensor.graph.fg import FunctionGraph + from pytensor.graph.fg import FunctionGraph, Output @contextmanager @@ -1055,7 +1055,7 @@ def count_minimum_peak(node_list, fgraph, 
nodes_mem): executable_nodes = set() for var in fgraph.inputs: for c, _ in fgraph.clients[var]: - if c != "output": + if not isinstance(c.op, Output): deps = c.inputs + destroy_dependencies[c] if all(compute_map[v][0] for v in deps): executable_nodes.add(c) @@ -1183,7 +1183,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of): for var in node.outputs: for c, _ in fgraph.clients[var]: - if c != "output": + if not isinstance(c.op, Output): deps = c.inputs + destroy_dependencies[c] if all(compute_map[v][0] for v in deps): new_exec_nodes.add(c) diff --git a/pytensor/graph/destroyhandler.py b/pytensor/graph/destroyhandler.py index e90dc01a26..80afaae259 100644 --- a/pytensor/graph/destroyhandler.py +++ b/pytensor/graph/destroyhandler.py @@ -11,6 +11,7 @@ from pytensor.configdefaults import config from pytensor.graph.basic import Constant from pytensor.graph.features import AlreadyThere, Bookkeeper +from pytensor.graph.fg import Output from pytensor.graph.utils import InconsistencyError from pytensor.misc.ordered_set import OrderedSet @@ -401,8 +402,6 @@ def has_destroyers(protected_list): def recursive_destroys_finder(protected_var): # protected_var is the idx'th input of app. for app, idx in fgraph.clients[protected_var]: - if app == "output": - continue destroy_maps = app.op.destroy_map.values() # If True means that the apply node, destroys the protected_var. if idx in [dmap for sublist in destroy_maps for dmap in sublist]: @@ -575,10 +574,10 @@ def on_prune(self, fgraph, app, reason): def on_change_input(self, fgraph, app, i, old_r, new_r, reason): """ - app.inputs[i] changed from old_r to new_r. + node.inputs[i] changed from old_r to new_r. """ - if app == "output": + if isinstance(app.op, Output): # app == 'output' is special key that means FunctionGraph is redefining which nodes are being # considered 'outputs' of the graph. 
pass diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index 7453a26aee..b79bcefd0f 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -3,7 +3,7 @@ import time from collections import OrderedDict from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any, Literal, Union, cast +from typing import Any, Union, cast import pytensor from pytensor.configdefaults import config @@ -19,15 +19,30 @@ ) from pytensor.graph.basic import as_string as graph_as_string from pytensor.graph.features import AlreadyThere, Feature, ReplaceValidate +from pytensor.graph.op import Op from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError from pytensor.misc.ordered_set import OrderedSet -if TYPE_CHECKING: - from pytensor.graph.op import Op +ClientType = tuple[Apply, int] -ApplyOrOutput = Apply | Literal["output"] -ClientType = tuple[ApplyOrOutput, int] + +class Output(Op): + """A dummy `Op` that represents an output variable in a `FunctionGraph`.""" + + __props__ = ("idx",) + + def __init__(self, idx): + self.idx = idx + + def make_node(self, inp): + return Apply(self, [inp], []) + + def perform(self, node, inputs, outputs): + raise RuntimeError("Output Ops should never be evaluated") + + def __str__(self): + return f"output[{self.idx}]" class FunctionGraph(MetaObject): @@ -157,7 +172,7 @@ def add_output( """Add a new variable as an output to this `FunctionGraph`.""" self.outputs.append(var) self.import_var(var, reason=reason, import_missing=import_missing) - self.clients[var].append(("output", len(self.outputs) - 1)) + self.clients[var].append((Output(len(self.outputs) - 1).make_node(var), 0)) def add_input(self, var: Variable, check: bool = True) -> None: """Add a new variable as an input to this `FunctionGraph`. @@ -198,10 +213,8 @@ def add_client(self, var: Variable, new_client: ClientType) -> None: A ``(node, i)`` pair such that ``node.inputs[i]`` is `var`. 
""" - if not isinstance(new_client[0], Apply) and new_client[0] != "output": - raise TypeError( - 'The first entry of `new_client` must be an `Apply` node or the string `"output"`' - ) + if not isinstance(new_client[0], Apply): + raise TypeError("The first entry of `new_client` must be an `Apply` node") self.clients[var].append(new_client) def remove_client( @@ -382,7 +395,7 @@ def import_node( def change_node_input( self, - node: ApplyOrOutput, + node: Apply, i: int, new_var: Variable, reason: str | None = None, @@ -401,9 +414,7 @@ def change_node_input( Parameters ---------- node - The node for which an input is to be changed. If the value is - the string ``"output"`` then the ``self.outputs`` will be used - instead of ``node.inputs``. + The node for which an input is to be changed. i The index in `node.inputs` that we want to change. new_var @@ -417,23 +428,16 @@ def change_node_input( narrowed and would otherwise fail this check. """ # TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?) - if node == "output": - r = self.outputs[i] - if check and not r.type.is_super(new_var.type): - raise TypeError( - f"The type of the replacement ({new_var.type}) must be " - f"compatible with the type of the original Variable ({r.type})." - ) + r = node.inputs[i] + if check and not r.type.is_super(new_var.type): + raise TypeError( + f"The type of the replacement ({new_var.type}) must be " + f"compatible with the type of the original Variable ({r.type})." + ) + node.inputs[i] = new_var + + if isinstance(node.op, Output): - self.outputs[i] = new_var + self.outputs[node.op.idx] = new_var - else: - assert isinstance(node, Apply) - r = node.inputs[i] - if check and not r.type.is_super(new_var.type): - raise TypeError( - f"The type of the replacement ({new_var.type}) must be " - f"compatible with the type of the original Variable ({r.type})." 
- ) - node.inputs[i] = new_var if r is new_var: return @@ -521,7 +525,7 @@ def replace_all(self, pairs: Iterable[tuple[Variable, Variable]], **kwargs) -> N def _remove_output(self, idx: int): """Remove the output at index `idx` and update the indices in the clients entries. - `FunctionGraph.clients` contains entries like ``("output", i)`` under + `FunctionGraph.clients` contains entries like ``(output(var), i)`` under each output variable in `FunctionGraph.outputs`. The ``i`` values correspond to each output's location within the `FunctionGraph.outputs` list, so, when an output is removed from the graph, all these entries @@ -533,16 +537,23 @@ def _remove_output(self, idx: int): which they're contained are already being updated in-place. """ old_idx_mappings = tuple((out, i) for i, out in enumerate(self.outputs)) + self.outputs.pop(idx) new_idx = 0 for out, old_idx in old_idx_mappings: if old_idx == idx: continue - out_clients = self.clients[out] - arrow: ClientType = ("output", old_idx) - arrow_idx = out_clients.index(arrow) - out_clients[arrow_idx] = ("output", new_idx) + + if old_idx != new_idx: + out_clients = self.clients[out] + [client_out_idx] = [ + i + for i, (out_client, _) in enumerate(out_clients) + if isinstance(out_client.op, Output) + and out_client.op.idx == old_idx + ] + out_clients[client_out_idx] = (Output(new_idx).make_node(out), 0) new_idx += 1 def remove_node(self, node: Apply, reason: str | None = None): @@ -570,8 +581,8 @@ def remove_node(self, node: Apply, reason: str | None = None): while out_clients: out_client, out_idx = out_clients.pop() - if out_client == "output": - self._remove_output(out_idx) + if isinstance(out_client.op, Output): + self._remove_output(out_client.op.idx) # TODO: We could short-circuit all of the graph walking and # clear everything at once when all the outputs are gone. 
@@ -630,32 +641,26 @@ def remove_node(self, node: Apply, reason: str | None = None): self.execute_callbacks("on_prune", node, reason) def remove_input(self, input_idx: int, reason: str | None = None): - """Remove the input at index `input_idx`.""" + """Remove the input at index `input_idx`. + + Any node that depended on such input will also be removed. + """ var = self.inputs.pop(input_idx) for client, idx in list(self.clients[var]): - if client == "output": - out_var = self.outputs[idx] - out_node = out_var.owner - if out_node is None: - assert out_var in self.inputs - self.outputs.pop(idx) - continue - client_node = out_node - else: - assert isinstance(client, Apply) - client_node = client - - self.remove_node(client_node, reason=reason) + self.remove_node(client, reason=reason) def remove_output(self, output_idx: int, reason: str | None = None): """Remove the output at index `input_idx`.""" var = self.outputs[output_idx] self._remove_output(output_idx) - self.remove_client( - var, ("output", output_idx), reason=reason, remove_if_empty=True + old_out_client = next( + (client, i) + for client, i in self.clients[var] + if isinstance(client.op, Output) and client.op.idx == output_idx ) + self.remove_client(var, old_out_client, reason=reason, remove_if_empty=True) def attach_feature(self, feature: Feature) -> None: """Add a ``graph.features.Feature`` to this function graph and trigger its ``on_attach`` callback.""" @@ -832,19 +837,17 @@ def check_integrity(self) -> None: ): raise Exception(f"Undeclared input: {variable}") for cl_node, i in self.clients[variable]: - if cl_node == "output": - if self.outputs[i] is not variable: + if isinstance(cl_node.op, Output): + out_idx = cl_node.op.idx + if self.outputs[out_idx] is not variable: raise Exception( - f"Inconsistent clients list: {variable}, {self.outputs[i]}" + f"Inconsistent clients list: {variable}, {self.outputs[out_idx]}" ) - continue - - assert isinstance(cl_node, Apply) - - if cl_node not in nodes: + elif 
cl_node not in nodes: raise Exception( f"Client not in FunctionGraph: {variable}, {(cl_node, i)}" ) + if cl_node.inputs[i] is not variable: raise Exception( f"Inconsistent clients list: {variable}, {cl_node.inputs[i]}" diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 584dccfba3..b691961af8 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -30,7 +30,7 @@ vars_between, ) from pytensor.graph.features import AlreadyThere, Feature, NodeFinder -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.misc.ordered_set import OrderedSet @@ -738,7 +738,7 @@ def apply(self, fgraph): if any( i in flatten(c.op.destroy_map.values()) for c, i in clients - if c != "output" and c.op.destroy_map + if c.op.destroy_map ): continue @@ -1626,8 +1626,6 @@ def transform(self, fgraph, node, get_nodes=True): if get_nodes and self.get_nodes is not None: for real_node in self.get_nodes(fgraph, node): - if real_node == "output": - continue ret = self.transform(fgraph, real_node, get_nodes=False) if ret is not False and ret is not None: return dict(zip(real_node.outputs, ret)) @@ -2407,7 +2405,7 @@ def importer(node): if self.tracks_on_change_inputs: def chin_(node, i, r, new_r, reason): - if node is not current_node and not isinstance(node, str): + if node is not current_node and not isinstance(node.op, Output): q.append(node) chin = chin_ diff --git a/pytensor/graph/rewriting/utils.py b/pytensor/graph/rewriting/utils.py index 685a18879b..c8138396a9 100644 --- a/pytensor/graph/rewriting/utils.py +++ b/pytensor/graph/rewriting/utils.py @@ -10,7 +10,7 @@ graph_inputs, vars_between, ) -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.rewriting.db import RewriteDatabaseQuery @@ -230,7 +230,7 @@ def 
get_clients_at_depth( for var in node.outputs: if depth > 0: for out_node, _ in fgraph.clients[var]: - if out_node == "output": + if isinstance(out_node.op, Output): continue yield from get_clients_at_depth( fgraph, cast(Apply, out_node), depth - 1 diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index e11247c9b3..0beb468c9b 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -354,9 +354,7 @@ def get_c_declare(fgraph, r, name, sub): # it means they need `r`'s dtype to be declared, so # we have to pass `check_input=True` to `c_declare`. if any( - getattr(c.op, "check_input", config.check_input) - for (c, _) in fgraph.clients[r] - if not isinstance(c, str) + getattr(c.op, "check_input", config.check_input) for (c, _) in fgraph.clients[r] ) or (r.owner and getattr(r.owner.op, "check_input", config.check_input)): c_declare = r.type.c_declare(name, sub, True) else: diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index 1cb3c01265..159dc219bd 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -956,8 +956,7 @@ def compute_gc_dependencies(self, variables): if k.owner and self.fgraph.clients[k]: ls = [] for cl in self.fgraph.clients[k]: - if cl[0] != "output": - ls += cl[0].outputs + ls += cl[0].outputs dependencies[k] += ls return dependencies diff --git a/pytensor/printing.py b/pytensor/printing.py index 80863a720c..120ad54e24 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -425,10 +425,15 @@ def debugprint( for out in inner_outputs: if ( - isinstance(getattr(out.owner, "op", None), HasInnerGraph) - or hasattr(getattr(out.owner, "op", None), "scalar_op") - and isinstance(out.owner.op.scalar_op, HasInnerGraph) - ) and out not in inner_graph_vars: + out.owner is not None + and ( + isinstance(out.owner.op, HasInnerGraph) + or isinstance( + getattr(out.owner.op, "scalar_op", None), HasInnerGraph + ) + ) + and out not in inner_graph_vars + ): inner_graph_vars.append(out) _debugprint( diff --git 
a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index ae128c608f..2593e397cf 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -27,7 +27,7 @@ ) from pytensor.graph.destroyhandler import DestroyHandler from pytensor.graph.features import ReplaceValidate -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import compute_test_value from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import ( @@ -1303,7 +1303,7 @@ def save_mem_new_scan(fgraph, node): for cl, _ in fgraph.clients[out]: # 2.1 outputs of the function # => output needs all its intermediate values - if isinstance(cl, str): + if isinstance(cl.op, Output): # if the node is actually an output, then # we need to store the entire thing global_nsteps = None @@ -1412,7 +1412,7 @@ def save_mem_new_scan(fgraph, node): for i, out in enumerate(node.outputs[:c_outs]): # look at all its clients for cl, _ in fgraph.clients[out]: - if isinstance(cl, str): + if isinstance(cl.op, Output): store_steps[i] = 0 break elif not isinstance(cl.op, Subtensor): @@ -2309,7 +2309,7 @@ def push_out_dot1_scan(fgraph, node): and isinstance(out.owner.op.scalar_op, ps.Add) and inp in out.owner.inputs and len(fgraph.clients[outer_out]) == 1 - and not isinstance(fgraph.clients[outer_out][0][0], str) + and not isinstance(fgraph.clients[outer_out][0][0], Output) and isinstance(fgraph.clients[outer_out][0][0].op, Subtensor) and fgraph.clients[outer_out][0][0].op.idx_list == (-1,) ): diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 518b55da99..56b3f98c80 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -24,7 +24,7 @@ from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.graph import RewriteDatabaseQuery from pytensor.graph.basic import Apply, Constant, Variable, equal_computations -from pytensor.graph.fg import FunctionGraph +from 
pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op from pytensor.graph.replace import _vectorize_node from pytensor.graph.rewriting.db import EquilibriumDB @@ -1654,7 +1654,7 @@ def do_constant_folding(self, fgraph, node): return False for client, idx in clients: - if client == "output": + if isinstance(client.op, Output): # If the output is a constant, it will have to be deepcopied # each time the function is called. So we do not fold. return False diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index b1960927e6..eb4990c9c9 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -3,6 +3,7 @@ from pytensor.compile import optdb from pytensor.configdefaults import config from pytensor.graph import ancestors +from pytensor.graph.fg import Output from pytensor.graph.op import compute_test_value from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter from pytensor.scalar import integer_types @@ -32,15 +33,13 @@ def is_rv_used_in_graph(base_rv, node, fgraph): TODO: We should apply all the shape rewrites before these rewrites, since that would properly remove the unnecessary dependencies on `base_rv` (when possible). 
- """ - - def _node_check(n, i): - if n == "output": - n = fgraph.outputs[i].owner - return n == node or isinstance(n.op, Shape | Shape_i) - - return not all(_node_check(n, i) for n, i in fgraph.clients.get(base_rv, ())) + """ + return any( + n + for n, i in fgraph.clients.get(base_rv, ()) + if not (n == node or isinstance(n.op, Shape | Shape_i | Output)) + ) @node_rewriter([RandomVariable], inplace=True) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index ac4918117e..bc668dbc95 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -14,7 +14,7 @@ from pytensor.graph import FunctionGraph from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort from pytensor.graph.features import ReplaceValidate -from pytensor.graph.fg import ApplyOrOutput +from pytensor.graph.fg import Output from pytensor.graph.rewriting.basic import ( EquilibriumGraphRewriter, GraphRewriter, @@ -688,7 +688,7 @@ def find_next_fuseable_subgraph( """ FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]] - UNFUSEABLE_MAPPING = defaultdict[Variable, set[ApplyOrOutput]] + UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] def initialize_fuseable_mappings( *, fg: FunctionGraph @@ -729,7 +729,6 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool: for client, _ in clients: if ( out_maybe_fuseable - and not isinstance(client, str) # "output" and isinstance(client.op, Elemwise) # and not isinstance(client.op.scalar_op, ps.Composite) and len(client.outputs) == 1 @@ -843,7 +842,7 @@ def variables_depend_on( implied_unfuseable_clients = { c for client in unfuseable_clients_clone.get(next_out, ()) - if not isinstance(client, str) # "output" + if not isinstance(client.op, Output) for c in client.outputs } diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index cdb1e59101..7ae87c70e0 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ 
-291,8 +291,6 @@ def local_det_chol(fgraph, node): """ (x,) = node.inputs for cl, xpos in fgraph.clients[x]: - if cl == "output": - continue if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky): L = cl.outputs[0] return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 06d023d780..1ea94f4f69 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1137,14 +1137,11 @@ def transform(self, fgraph, node): # this canonized graph... if so, we do nothing and wait for # them to be transformed. for c, c_idx in out_clients: - if c == "output": - continue while ( - isinstance(getattr(c, "op", None), DimShuffle) - and len(fgraph.clients[c.outputs[0]]) <= 1 + isinstance(c.op, DimShuffle) and len(fgraph.clients[c.outputs[0]]) <= 1 ): c = fgraph.clients[c.outputs[0]][0][0] - if getattr(c, "op", "") in [self.main, self.inverse, self.reciprocal]: + if c.op in [self.main, self.inverse, self.reciprocal]: return False # Here we make the canonical version of the graph around this node diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 2ec1afa930..bdd5b8003a 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -401,7 +401,7 @@ def update_shape(self, r, other_r): merged_shape.append(other_shape[i]) elif ( ps.owner - and isinstance(getattr(ps.owner, "op", None), Shape_i) + and isinstance(ps.owner.op, Shape_i) and ps.owner.op.i == i and ps.owner.inputs[0] in (r, other_r) ): @@ -602,7 +602,7 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason): # r is *scheduled*. 
# At that point, node is no longer a client of r, but of new_r for shpnode, idx in fgraph.clients[r] + [(node, i)]: - if isinstance(getattr(shpnode, "op", None), Shape_i): + if isinstance(shpnode.op, Shape_i): idx = shpnode.op.i repl = self.shape_of[new_r][idx] if repl.owner is shpnode: @@ -1028,7 +1028,10 @@ def local_Shape_of_SpecifyShape(fgraph, node): specified_shape = node.inputs[0] - if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape): + if not ( + specified_shape.owner is not None + and isinstance(specified_shape.owner.op, SpecifyShape) + ): return False x, *shape = specified_shape.owner.inputs diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index 1d2af0c7f0..05e02148e2 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -6,7 +6,7 @@ from pytensor.configdefaults import config from pytensor.graph.basic import NominalVariable -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.utils import MissingInputError from tests.graph.utils import ( MyConstant, @@ -77,8 +77,13 @@ def test_init(self): assert fg.variables == {var1, var2, var3, var4} assert fg.get_clients(var1) == [(var3.owner, 0)] assert fg.get_clients(var2) == [(var4.owner, 1)] - assert fg.get_clients(var3) == [("output", 0), (var4.owner, 0)] - assert fg.get_clients(var4) == [("output", 1)] + var3_clients = fg.get_clients(var3) + assert len(var3_clients) == 2 + assert var3_clients[0][0].op == Output(0) + assert var3_clients[1] == (var4.owner, 0) + var4_clients = fg.get_clients(var4) + assert len(var4_clients) == 1 + assert var4_clients[0][0].op == Output(1) varC = MyConstant("varC") var5 = op1(var1, varC) @@ -207,8 +212,11 @@ def test_change_input(self): fg = FunctionGraph([var1, var2], [var3, var5], clone=False) var6 = MyVariable2("var6") + [out_client] = [ + cl for cl, _ in fg.clients[fg.outputs[0]] if isinstance(cl.op, Output) + ] with pytest.raises(TypeError): - 
fg.change_node_input("output", 1, var6) + fg.change_node_input(out_client, 0, var6) with pytest.raises(TypeError): fg.change_node_input(var5.owner, 1, var6) @@ -357,12 +365,13 @@ def test_check_integrity(self): # TODO: What if the index value is greater than 1? It will throw an # `IndexError`, but that doesn't sound like anything we'd want. + out_node = Output(idx=1).make_node(var4) with pytest.raises(Exception, match="Inconsistent clients list.*"): - fg.add_client(var4, ("output", 1)) + fg.add_client(var4, (out_node, 0)) fg.check_integrity() - fg.remove_client(var4, ("output", 1)) + fg.remove_client(var4, (out_node, 0)) with pytest.raises(TypeError, match="The first entry of.*"): fg.add_client(var4, (None, 0))