From 55632f14f216b8f6b668924760771de2dbf45f8f Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Mon, 9 Jan 2023 20:37:43 -0600 Subject: [PATCH] Fix handling of constant inputs to SubsumingElemwise --- aemcmc/basic.py | 2 +- aemcmc/rewriting.py | 32 ++++++++++++++++++++++++-------- tests/test_rewriting.py | 28 +++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/aemcmc/basic.py b/aemcmc/basic.py index d69fbf8..2a568a5 100644 --- a/aemcmc/basic.py +++ b/aemcmc/basic.py @@ -127,6 +127,6 @@ def construct_sampler( if rv not in obs_rvs_to_values } - return Sampler(sampling_steps, updates, parameters), { + return Sampler(sampling_steps, posterior_updates, parameters), { new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items() } diff --git a/aemcmc/rewriting.py b/aemcmc/rewriting.py index 9039440..470cc3c 100644 --- a/aemcmc/rewriting.py +++ b/aemcmc/rewriting.py @@ -5,7 +5,7 @@ from aeppl.rewriting import MeasurableConversionTracker from aesara.compile.builders import OpFromGraph from aesara.compile.mode import optdb -from aesara.graph.basic import Apply, Variable, clone_replace, io_toposort +from aesara.graph.basic import Apply, Constant, Variable, clone_replace, io_toposort from aesara.graph.features import AlreadyThere, Feature from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op @@ -50,7 +50,7 @@ def construct_ir_fgraph( - obs_rvs_to_values: Dict[Variable, Variable] + obs_rvs_to_values: Dict[Variable, Variable], clone=True ) -> Tuple[ FunctionGraph, Dict[Variable, Variable], @@ -80,7 +80,7 @@ def construct_ir_fgraph( fgraph = FunctionGraph( outputs=rv_outputs, - clone=True, + clone=clone, memo=memo, copy_orphans=False, copy_inputs=False, @@ -88,7 +88,7 @@ def construct_ir_fgraph( ) # Update `obs_rvs_to_values` so that it uses the new cloned variables - obs_rvs_to_values = {memo[k]: v for k, v in obs_rvs_to_values.items()} + obs_rvs_to_values = {memo.get(k, k): v for k, v in obs_rvs_to_values.items()} sampler_ir_db.query("+basic").rewrite(fgraph) @@ -186,10 +186,28 @@ def __init__(self, inputs, outputs, *args, **kwargs): # self.destroy_map = self.elemwise_op.destroy_map self.ufunc = None self.nfunc = None - OpFromGraph.__init__(self, inputs, outputs, *args, **kwargs) + + used_inputs = [inp for inp in inputs if not isinstance(inp, Constant)] + + OpFromGraph.__init__(self, used_inputs, outputs, *args, **kwargs) def make_node(self, *inputs): - node = super().make_node(*inputs) + # Remove constants + used_inputs = [inp for inp in inputs if not isinstance(inp, Constant)] + + # TODO: We could make sure that the new constant inputs correspond to + # the originals... + + # The user interface doesn't expect the shared variable inputs of the + # inner-graph, but, since `Op.make_node` does (and `Op.__call__` + # dispatches to `Op.make_node`), we need to compensate here + num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs) + + if len(used_inputs) == num_expected_inps: + used_inputs = used_inputs + self.shared_inputs + + node = super().make_node(*used_inputs) + # Remove shared variable inputs. We aren't going to compute anything # with this `Op`, so they're not needed real_inputs = node.inputs[: len(node.inputs) - len(self.shared_inputs)] @@ -346,8 +364,6 @@ def local_elemwise_dimshuffle_subsume(fgraph, node): new_out = new_op(*new_inputs) - assert len(new_out.owner.inputs) == len(node.inputs) - return new_out.owner.outputs diff --git a/tests/test_rewriting.py b/tests/test_rewriting.py index ef07ba4..41b0795 100644 --- a/tests/test_rewriting.py +++ b/tests/test_rewriting.py @@ -6,7 +6,11 @@ from etuples import etuple, etuplize from unification import unify -from aemcmc.rewriting import SubsumingElemwise, local_elemwise_dimshuffle_subsume +from aemcmc.rewriting import ( + SubsumingElemwise, + construct_ir_fgraph, + local_elemwise_dimshuffle_subsume, +) def test_SubsumingElemwise_basics(): @@ -107,3 +111,25 @@ def test_local_elemwise_dimshuffle_subsume_transpose(): # The input corresponding to `b`/`b_ds` should be equivalent to `b.T` assert isinstance(res.owner.inputs[1].owner.op, DimShuffle) assert equal_computations([b.T], [res.owner.inputs[1]]) + + +def test_SubsumingElemwise_constant_inputs(): + """Make sure constant inputs are handled correctly by `SubsumingElemwise`.""" + + srng = at.random.RandomStream(0) + + s = at.lscalar("s") + # The `1` is the constant input to a `true_div` `Elemwise` that should be + # "subsumed" + Z = srng.exponential(1, size=s, name="Z") + mu = 1 / Z + Y = srng.normal(mu, name="Y") + y = Y.clone() + y.name = "y" + + res, *_ = construct_ir_fgraph({Y: y}, clone=False) + + normal_node = res.outputs[1].owner + subelem_node = normal_node.inputs[3].owner + assert isinstance(subelem_node.op, SubsumingElemwise) + assert subelem_node.inputs == [Z]