Skip to content

Commit

Permalink
Fix handling of constant inputs to SubsumingElemwise
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jan 10, 2023
1 parent 6fe3c54 commit 55632f1
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 10 deletions.
2 changes: 1 addition & 1 deletion aemcmc/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,6 @@ def construct_sampler(
if rv not in obs_rvs_to_values
}

return Sampler(sampling_steps, updates, parameters), {
return Sampler(sampling_steps, posterior_updates, parameters), {
new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()
}
32 changes: 24 additions & 8 deletions aemcmc/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aeppl.rewriting import MeasurableConversionTracker
from aesara.compile.builders import OpFromGraph
from aesara.compile.mode import optdb
from aesara.graph.basic import Apply, Variable, clone_replace, io_toposort
from aesara.graph.basic import Apply, Constant, Variable, clone_replace, io_toposort
from aesara.graph.features import AlreadyThere, Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
Expand Down Expand Up @@ -50,7 +50,7 @@


def construct_ir_fgraph(
obs_rvs_to_values: Dict[Variable, Variable]
obs_rvs_to_values: Dict[Variable, Variable], clone=True
) -> Tuple[
FunctionGraph,
Dict[Variable, Variable],
Expand Down Expand Up @@ -80,15 +80,15 @@ def construct_ir_fgraph(

fgraph = FunctionGraph(
outputs=rv_outputs,
clone=True,
clone=clone,
memo=memo,
copy_orphans=False,
copy_inputs=False,
features=[ShapeFeature(), MeasurableConversionTracker()],
)

# Update `obs_rvs_to_values` so that it uses the new cloned variables
obs_rvs_to_values = {memo[k]: v for k, v in obs_rvs_to_values.items()}
obs_rvs_to_values = {memo.get(k, k): v for k, v in obs_rvs_to_values.items()}

sampler_ir_db.query("+basic").rewrite(fgraph)

Expand Down Expand Up @@ -186,10 +186,28 @@ def __init__(self, inputs, outputs, *args, **kwargs):
# self.destroy_map = self.elemwise_op.destroy_map
self.ufunc = None
self.nfunc = None
OpFromGraph.__init__(self, inputs, outputs, *args, **kwargs)

used_inputs = [inp for inp in inputs if not isinstance(inp, Constant)]

OpFromGraph.__init__(self, used_inputs, outputs, *args, **kwargs)

def make_node(self, *inputs):
node = super().make_node(*inputs)
# Remove constants
used_inputs = [inp for inp in inputs if not isinstance(inp, Constant)]

# TODO: We could make sure that the new constant inputs correspond to
# the originals...

# The user interface doesn't expect the shared variable inputs of the
# inner-graph, but, since `Op.make_node` does (and `Op.__call__`
# dispatches to `Op.make_node`), we need to compensate here
num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)

if len(used_inputs) == num_expected_inps:
used_inputs = used_inputs + self.shared_inputs

node = super().make_node(*used_inputs)

# Remove shared variable inputs. We aren't going to compute anything
# with this `Op`, so they're not needed
real_inputs = node.inputs[: len(node.inputs) - len(self.shared_inputs)]
Expand Down Expand Up @@ -346,8 +364,6 @@ def local_elemwise_dimshuffle_subsume(fgraph, node):

new_out = new_op(*new_inputs)

assert len(new_out.owner.inputs) == len(node.inputs)

return new_out.owner.outputs


Expand Down
28 changes: 27 additions & 1 deletion tests/test_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from etuples import etuple, etuplize
from unification import unify

from aemcmc.rewriting import SubsumingElemwise, local_elemwise_dimshuffle_subsume
from aemcmc.rewriting import (
SubsumingElemwise,
construct_ir_fgraph,
local_elemwise_dimshuffle_subsume,
)


def test_SubsumingElemwise_basics():
Expand Down Expand Up @@ -107,3 +111,25 @@ def test_local_elemwise_dimshuffle_subsume_transpose():
# The input corresponding to `b`/`b_ds` should be equivalent to `b.T`
assert isinstance(res.owner.inputs[1].owner.op, DimShuffle)
assert equal_computations([b.T], [res.owner.inputs[1]])


def test_SubsumingElemwise_constant_inputs():
"""Make sure constant inputs are handled correctly by `SubsumingElemwise`."""

srng = at.random.RandomStream(0)

s = at.lscalar("s")
# The `1` is the constant input to a `true_div` `Elemwise` that should be
# "subsumed"
Z = srng.exponential(1, size=s, name="Z")
mu = 1 / Z
Y = srng.normal(mu, name="Y")
y = Y.clone()
y.name = "y"

res, *_ = construct_ir_fgraph({Y: y}, clone=False)

normal_node = res.outputs[1].owner
subelem_node = normal_node.inputs[3].owner
assert isinstance(subelem_node.op, SubsumingElemwise)
assert subelem_node.inputs == [Z]

0 comments on commit 55632f1

Please sign in to comment.