diff --git a/aemcmc/basic.py b/aemcmc/basic.py index d69fbf8..2a568a5 100644 --- a/aemcmc/basic.py +++ b/aemcmc/basic.py @@ -127,6 +127,6 @@ def construct_sampler( if rv not in obs_rvs_to_values } - return Sampler(sampling_steps, updates, parameters), { + return Sampler(sampling_steps, posterior_updates, parameters), { new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items() } diff --git a/aemcmc/rewriting.py b/aemcmc/rewriting.py index 9039440..81befe5 100644 --- a/aemcmc/rewriting.py +++ b/aemcmc/rewriting.py @@ -5,7 +5,7 @@ from aeppl.rewriting import MeasurableConversionTracker from aesara.compile.builders import OpFromGraph from aesara.compile.mode import optdb -from aesara.graph.basic import Apply, Variable, clone_replace, io_toposort +from aesara.graph.basic import Apply, Constant, Variable, clone_replace, io_toposort from aesara.graph.features import AlreadyThere, Feature from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op @@ -186,10 +186,28 @@ def __init__(self, inputs, outputs, *args, **kwargs): # self.destroy_map = self.elemwise_op.destroy_map self.ufunc = None self.nfunc = None - OpFromGraph.__init__(self, inputs, outputs, *args, **kwargs) + + used_inputs = [inp for inp in inputs if not isinstance(inp, Constant)] + + OpFromGraph.__init__(self, used_inputs, outputs, *args, **kwargs) def make_node(self, *inputs): - node = super().make_node(*inputs) + # Remove constants + used_inputs = [inp for inp in inputs if not isinstance(inp, Constant)] + + # TODO: We could make sure that the new constant inputs correspond to + # the originals... + + # The user interface doesn't expect the shared variable inputs of the + # inner-graph, but, since `Op.make_node` does (and `Op.__call__` + # dispatches to `Op.make_node`), we need to compensate here + num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs) + + if len(used_inputs) == num_expected_inps: + used_inputs = used_inputs + self.shared_inputs + + node = super().make_node(*used_inputs) + # Remove shared variable inputs. We aren't going to compute anything # with this `Op`, so they're not needed real_inputs = node.inputs[: len(node.inputs) - len(self.shared_inputs)] @@ -346,8 +364,6 @@ def local_elemwise_dimshuffle_subsume(fgraph, node): new_out = new_op(*new_inputs) - assert len(new_out.owner.inputs) == len(node.inputs) - return new_out.owner.outputs diff --git a/tests/test_rewriting.py b/tests/test_rewriting.py index ef07ba4..f616224 100644 --- a/tests/test_rewriting.py +++ b/tests/test_rewriting.py @@ -6,7 +6,11 @@ from etuples import etuple, etuplize from unification import unify -from aemcmc.rewriting import SubsumingElemwise, local_elemwise_dimshuffle_subsume +from aemcmc.rewriting import ( + SubsumingElemwise, + construct_ir_fgraph, + local_elemwise_dimshuffle_subsume, +) def test_SubsumingElemwise_basics(): @@ -107,3 +111,23 @@ def test_local_elemwise_dimshuffle_subsume_transpose(): # The input corresponding to `b`/`b_ds` should be equivalent to `b.T` assert isinstance(res.owner.inputs[1].owner.op, DimShuffle) assert equal_computations([b.T], [res.owner.inputs[1]]) + + +def test_SubsumingElemwise_constant_inputs(): + """Make sure constant inputs are handled correctly by `SubsumingElemwise`.""" + + srng = at.random.RandomStream(0) + + s = at.lscalar("s") + # The `1` is the constant input to a `true_div` `Elemwise` that should be + # "subsumed" + mu = 1 / srng.exponential(1, size=s, name="Z") + Y = srng.normal(mu, name="Y") + y = Y.clone() + y.name = "y" + + res, *_ = construct_ir_fgraph({Y: y}) + + normal_node = res.outputs[1].owner + subelem_node = normal_node.inputs[3].owner + assert isinstance(subelem_node.op, SubsumingElemwise)