Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of constant inputs to SubsumingElemwise #93

Merged
merged 1 commit into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aemcmc/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,6 @@ def construct_sampler(
if rv not in obs_rvs_to_values
}

return Sampler(sampling_steps, updates, parameters), {
return Sampler(sampling_steps, posterior_updates, parameters), {
new_to_old_rvs[rv]: init_var for rv, init_var in rvs_to_init_vals.items()
}
32 changes: 24 additions & 8 deletions aemcmc/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aeppl.rewriting import MeasurableConversionTracker
from aesara.compile.builders import OpFromGraph
from aesara.compile.mode import optdb
from aesara.graph.basic import Apply, Variable, clone_replace, io_toposort
from aesara.graph.basic import Apply, Constant, Variable, clone_replace, io_toposort
from aesara.graph.features import AlreadyThere, Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
Expand Down Expand Up @@ -50,7 +50,7 @@


def construct_ir_fgraph(
obs_rvs_to_values: Dict[Variable, Variable]
obs_rvs_to_values: Dict[Variable, Variable], clone=True
) -> Tuple[
FunctionGraph,
Dict[Variable, Variable],
Expand Down Expand Up @@ -80,15 +80,15 @@ def construct_ir_fgraph(

fgraph = FunctionGraph(
outputs=rv_outputs,
clone=True,
clone=clone,
memo=memo,
copy_orphans=False,
copy_inputs=False,
features=[ShapeFeature(), MeasurableConversionTracker()],
)

# Update `obs_rvs_to_values` so that it uses the new cloned variables
obs_rvs_to_values = {memo[k]: v for k, v in obs_rvs_to_values.items()}
obs_rvs_to_values = {memo.get(k, k): v for k, v in obs_rvs_to_values.items()}

sampler_ir_db.query("+basic").rewrite(fgraph)

Expand Down Expand Up @@ -186,10 +186,28 @@ def __init__(self, inputs, outputs, *args, **kwargs):
# self.destroy_map = self.elemwise_op.destroy_map
self.ufunc = None
self.nfunc = None
OpFromGraph.__init__(self, inputs, outputs, *args, **kwargs)

used_inputs = [inp for inp in inputs if not isinstance(inp, Constant)]

OpFromGraph.__init__(self, used_inputs, outputs, *args, **kwargs)

def make_node(self, *inputs):
node = super().make_node(*inputs)
# Remove constants
used_inputs = [inp for inp in inputs if not isinstance(inp, Constant)]

# TODO: We could make sure that the new constant inputs correspond to
# the originals...

# The user interface doesn't expect the shared variable inputs of the
# inner-graph, but, since `Op.make_node` does (and `Op.__call__`
# dispatches to `Op.make_node`), we need to compensate here
num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)

if len(used_inputs) == num_expected_inps:
used_inputs = used_inputs + self.shared_inputs

node = super().make_node(*used_inputs)

# Remove shared variable inputs. We aren't going to compute anything
# with this `Op`, so they're not needed
real_inputs = node.inputs[: len(node.inputs) - len(self.shared_inputs)]
Expand Down Expand Up @@ -346,8 +364,6 @@ def local_elemwise_dimshuffle_subsume(fgraph, node):

new_out = new_op(*new_inputs)

assert len(new_out.owner.inputs) == len(node.inputs)

return new_out.owner.outputs


Expand Down
28 changes: 27 additions & 1 deletion tests/test_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from etuples import etuple, etuplize
from unification import unify

from aemcmc.rewriting import SubsumingElemwise, local_elemwise_dimshuffle_subsume
from aemcmc.rewriting import (
SubsumingElemwise,
construct_ir_fgraph,
local_elemwise_dimshuffle_subsume,
)


def test_SubsumingElemwise_basics():
Expand Down Expand Up @@ -107,3 +111,25 @@ def test_local_elemwise_dimshuffle_subsume_transpose():
# The input corresponding to `b`/`b_ds` should be equivalent to `b.T`
assert isinstance(res.owner.inputs[1].owner.op, DimShuffle)
assert equal_computations([b.T], [res.owner.inputs[1]])


def test_SubsumingElemwise_constant_inputs():
    """Check that `SubsumingElemwise` copes with constant `Elemwise` inputs.

    A small model is built in which a `true_div` `Elemwise` has the literal
    `1` as one of its inputs.  After the IR graph is constructed, the node
    that replaced that `Elemwise` must be a `SubsumingElemwise` whose only
    input is `Z`.
    """
    rng_stream = at.random.RandomStream(0)
    size = at.lscalar("s")

    # `1 / Z` is the `true_div` `Elemwise` that should be "subsumed"; its
    # numerator is the constant input under test.
    Z = rng_stream.exponential(1, size=size, name="Z")
    Y = rng_stream.normal(1 / Z, name="Y")

    # The observed value is a clone of `Y`, renamed to tell the two apart
    y = Y.clone()
    y.name = "y"

    fgraph, *_ = construct_ir_fgraph({Y: y}, clone=False)

    # NOTE(review): output index 1 is taken to be the normal variable and
    # input index 3 its mean parameter -- confirm against the output ordering
    # of `construct_ir_fgraph` and the `RandomVariable` input layout.
    subsumed_node = fgraph.outputs[1].owner.inputs[3].owner

    assert isinstance(subsumed_node.op, SubsumingElemwise)
    # The constant must not show up among the node's inputs
    assert subsumed_node.inputs == [Z]