Skip to content

Commit

Permalink
Derive probability for transforms with implicit broadcasting
Browse files Browse the repository at this point in the history
A warning is issued as this graph is unlikely to be desired for most users.
  • Loading branch information
ricardoV94 committed Jun 30, 2023
1 parent 4bd8c30 commit 0ec578c
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 9 deletions.
30 changes: 27 additions & 3 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
# SOFTWARE.

import abc
import warnings

from copy import copy
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -111,6 +112,7 @@
cleanup_ir_rewrites_db,
measurable_ir_rewrites_db,
)
from pymc.logprob.shape import measurable_broadcast
from pymc.logprob.utils import check_potential_measurability


Expand Down Expand Up @@ -564,6 +566,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li

scalar_op = node.op.scalar_op
measurable_input_idx = 0
measurable_input_broadcast = (

Check warning on line 569 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L569

Added line #L569 was not covered by tests
measurable_input.type.broadcastable != node.default_output().type.broadcastable
)
transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,)
transform: RVTransform

Expand All @@ -588,22 +593,41 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
return None
try:
(power,) = other_inputs
power = pt.get_underlying_scalar_constant_value(power).item()
base_power = pt.get_underlying_scalar_constant_value(power).item()

Check warning on line 596 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L596

Added line #L596 was not covered by tests
# Power needs to be a constant
except NotScalarConstantError:
return None
transform_inputs = (measurable_input, power)
transform = PowerTransform(power=power)
transform = PowerTransform(power=base_power)

Check warning on line 601 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L601

Added line #L601 was not covered by tests
elif isinstance(scalar_op, Add):
transform_inputs = (measurable_input, pt.add(*other_inputs))
transform = LocTransform(
transform_args_fn=lambda *inputs: inputs[-1],
)
elif transform is None:
elif isinstance(scalar_op, Mul):

Check warning on line 607 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L607

Added line #L607 was not covered by tests
transform_inputs = (measurable_input, pt.mul(*other_inputs))
transform = ScaleTransform(
transform_args_fn=lambda *inputs: inputs[-1],
)
else:
raise TypeError(
f"Scalar Op not supported: {scalar_op}. Rewrite should not have been triggered"
) # pragma: no cover

if measurable_input_broadcast:

Check warning on line 617 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L617

Added line #L617 was not covered by tests
# This rewrite logic only supports broadcasting for transforms with two inputs, where the first is measurable.
# This covers all current cases, update if other cases are supported in the future.
if len(transform_inputs) != 2 or measurable_input_idx != 0:
return None
warnings.warn(

Check warning on line 622 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L620-L622

Added lines #L620 - L622 were not covered by tests
"MeasurableTransform with implicit broadcasting detected. This corresponds to a potentially degenerate probability graph.\n"
"If you did not intend this, make sure the base measurable variable is created with all the dimensions from the start."
"Otherwise, an explicit `broadcast_to` operation can be used to silence this warning.\n",
UserWarning,
)
measurable_input, other_input = transform_inputs
measurable_input = measurable_broadcast(measurable_input, other_input.shape)
transform_inputs = (measurable_input, other_input)

Check warning on line 630 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L628-L630

Added lines #L628 - L630 were not covered by tests

transform_op = MeasurableTransform(
scalar_op=scalar_op,
Expand Down
32 changes: 26 additions & 6 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,16 +807,36 @@ def test_discrete_rv_multinary_transform_fails():
conditional_logp({y_rv: y_rv.clone()})


@pytest.mark.xfail(reason="Check not implemented yet")
def test_invalid_broadcasted_transform_rv_fails():
@pytest.mark.filterwarnings("error") # Fail if unexpected warning is issued
@pytest.mark.parametrize("implicit_broadcast", (True, False))
def test_broadcasted_transform_rv(implicit_broadcast):
loc = pt.vector("loc")
y_rv = loc + pt.random.normal(0, 1, size=1, name="base_rv")
base_rv = pt.random.normal(0, 1, size=1, name="base_rv")
if implicit_broadcast:
y_rv = loc + base_rv
else:
y_rv = loc + pt.broadcast_to(base_rv, shape=loc.shape)
y_rv.name = "y"
y_vv = y_rv.clone()

# This logp derivation should fail or count only once the values that are broadcasted
logprob = logp(y_rv, y_vv)
assert logprob.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}).shape == ()
if implicit_broadcast:
with pytest.warns(UserWarning, match="implicit broadcasting detected"):
logprob = logp(y_rv, y_vv)
else:
logprob = logp(y_rv, y_vv)
logprob_fn = pytensor.function([loc, y_vv], logprob)

# All values must have the same offset from `loc`
np.testing.assert_allclose(
logprob_fn([1, 1, 1, 1], [0, 0, 0, 0]), sp.stats.norm.logpdf([0], loc=1)
)
np.testing.assert_allclose(
logprob_fn([1, 2, 3, 4], [0, 1, 2, 3]), sp.stats.norm.logpdf([0], loc=1)
)

# Otherwise probability is 0
np.testing.assert_array_equal(logprob_fn([1, 1, 1, 1], [0, 0, 0, 1]), [-np.inf])
np.testing.assert_array_equal(logprob_fn([1, 2, 3, 4], [0, 0, 0, 0]), [-np.inf])


@pytest.mark.parametrize("numerator", (1.0, 2.0))
Expand Down

0 comments on commit 0ec578c

Please sign in to comment.