Skip to content

Commit

Permalink
Fix handling of constant inputs to SubsumingElemwise
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jan 10, 2023
1 parent 6fe3c54 commit 37a4cf3
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 7 deletions.
2 changes: 1 addition & 1 deletion aemcmc/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,6 @@ def construct_sampler(
if rv not in obs_rvs_to_values
}

return Sampler(sampling_steps, updates, parameters), {
return Sampler(sampling_steps, posterior_updates, parameters), {
new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()
}
26 changes: 21 additions & 5 deletions aemcmc/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aeppl.rewriting import MeasurableConversionTracker
from aesara.compile.builders import OpFromGraph
from aesara.compile.mode import optdb
from aesara.graph.basic import Apply, Variable, clone_replace, io_toposort
from aesara.graph.basic import Apply, Constant, Variable, clone_replace, io_toposort
from aesara.graph.features import AlreadyThere, Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
Expand Down Expand Up @@ -186,10 +186,28 @@ def __init__(self, inputs, outputs, *args, **kwargs):
# self.destroy_map = self.elemwise_op.destroy_map
self.ufunc = None
self.nfunc = None
OpFromGraph.__init__(self, inputs, outputs, *args, **kwargs)

used_inputs = [inp for inp in inputs if not isinstance(inp, Constant)]

OpFromGraph.__init__(self, used_inputs, outputs, *args, **kwargs)

def make_node(self, *inputs):
node = super().make_node(*inputs)
# Remove constants
used_inputs = [inp for inp in inputs if not isinstance(inp, Constant)]

# TODO: We could make sure that the new constant inputs correspond to
# the originals...

# The user interface doesn't expect the shared variable inputs of the
# inner-graph, but, since `Op.make_node` does (and `Op.__call__`
# dispatches to `Op.make_node`), we need to compensate here
num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)

if len(used_inputs) == num_expected_inps:
used_inputs = used_inputs + self.shared_inputs

node = super().make_node(*used_inputs)

# Remove shared variable inputs. We aren't going to compute anything
# with this `Op`, so they're not needed
real_inputs = node.inputs[: len(node.inputs) - len(self.shared_inputs)]
Expand Down Expand Up @@ -346,8 +364,6 @@ def local_elemwise_dimshuffle_subsume(fgraph, node):

new_out = new_op(*new_inputs)

assert len(new_out.owner.inputs) == len(node.inputs)

return new_out.owner.outputs


Expand Down
26 changes: 25 additions & 1 deletion tests/test_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from etuples import etuple, etuplize
from unification import unify

from aemcmc.rewriting import SubsumingElemwise, local_elemwise_dimshuffle_subsume
from aemcmc.rewriting import (
SubsumingElemwise,
construct_ir_fgraph,
local_elemwise_dimshuffle_subsume,
)


def test_SubsumingElemwise_basics():
Expand Down Expand Up @@ -107,3 +111,23 @@ def test_local_elemwise_dimshuffle_subsume_transpose():
# The input corresponding to `b`/`b_ds` should be equivalent to `b.T`
assert isinstance(res.owner.inputs[1].owner.op, DimShuffle)
assert equal_computations([b.T], [res.owner.inputs[1]])


def test_SubsumingElemwise_constant_inputs():
"""Make sure constant inputs are handled correctly by `SubsumingElemwise`."""

srng = at.random.RandomStream(0)

s = at.lscalar("s")
# The `1` is the constant input to a `true_div` `Elemwise` that should be
# "subsumed"
mu = 1 / srng.exponential(1, size=s, name="Z")
Y = srng.normal(mu, name="Y")
y = Y.clone()
y.name = "y"

res, *_ = construct_ir_fgraph({Y: y})

normal_node = res.outputs[1].owner
subelem_node = normal_node.inputs[3].owner
assert isinstance(subelem_node.op, SubsumingElemwise)

0 comments on commit 37a4cf3

Please sign in to comment.