From e7da95ccdbc1009acc0b070b095fa3c86ec3d715 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 26 Jan 2021 15:56:47 -0500 Subject: [PATCH] Directly convert IndependentDistribution to base distribution --- funsor/distribution.py | 29 ++++++++++++++++++++++------- test/test_distribution_generic.py | 20 ++++++++++++-------- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 81af49460..6f7492ce2 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -19,9 +19,9 @@ from funsor.domains import Array, Real, Reals from funsor.gaussian import Gaussian from funsor.interpreter import gensym -from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype, +from funsor.tensor import (Function, Tensor, align_tensors, dummy_numeric_array, get_default_prototype, ignore_jit_warnings, numeric_array, stack) -from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, \ +from funsor.terms import Funsor, FunsorMeta, Independent, Lambda, Number, Variable, \ eager, reflect, to_data, to_funsor from funsor.util import broadcast_shape, get_backend, getargspec, lazy_property @@ -373,14 +373,29 @@ def backenddist_to_funsor(funsor_dist_class, backend_dist, output=None, dim_to_n def indepdist_to_funsor(backend_dist, output=None, dim_to_name=None): + if dim_to_name is None: + dim_to_name = {} + event_dim_to_name = OrderedDict((i, "_pyro_event_dim_{}".format(i)) + for i in range(-backend_dist.reinterpreted_batch_ndims, 0)) dim_to_name = OrderedDict((dim - backend_dist.reinterpreted_batch_ndims, name) for dim, name in dim_to_name.items()) - dim_to_name.update(OrderedDict((i, "_pyro_event_dim_{}".format(i)) - for i in range(-backend_dist.reinterpreted_batch_ndims, 0))) + dim_to_name.update(event_dim_to_name) result = to_funsor(backend_dist.base_dist, dim_to_name=dim_to_name) - for i in reversed(range(-backend_dist.reinterpreted_batch_ndims, 0)): - name = "_pyro_event_dim_{}".format(i) - result = funsor.terms.Independent(result, "value", name, "value") + if isinstance(result, Distribution) and \ + not isinstance(result.value, Function): # Function used in some eager patterns + params = tuple(result.params.values())[:-1] + for dim, name in reversed(event_dim_to_name.items()): + dim_var = to_funsor(name, result.inputs[name]) + params = tuple(Lambda(dim_var, param) for param in params) + if isinstance(result.value, Variable): + # broadcasting logic in Distribution will compute correct value domain + result = type(result)(*(params + (result.value.name,))) + else: + raise NotImplementedError("TODO support converting Indep(Transform)") + else: + # this handles the output of eager rewrites, e.g. Normal->Gaussian or Beta->Dirichlet + for dim, name in reversed(event_dim_to_name.items()): + result = funsor.terms.Independent(result, "value", name, "value") return result diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index a2a3d21e8..802f857ff 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -52,16 +52,20 @@ def __getattribute__(self, attr): backend_dist.ExpandedDistribution = backend_dist.torch_distribution.ExpandedDistribution -def normalize_with_subs(cls, *args): +def eager_no_dists(cls, *args): """ - This interpretation is like normalize, except it also evaluates Subs eagerly. + This interpretation is like eager, except it skips special distribution patterns. This is necessary because we want to convert distribution expressions to normal form in some tests, but do not want to trigger eager patterns that rewrite some distributions (e.g. Normal to Gaussian) since these tests are specifically intended to exercise funsor.distribution.Distribution. """ - result = normalize.dispatch(cls, *args)(*args) + if issubclass(cls, funsor.distribution.Distribution) and not isinstance(args[-1], funsor.Tensor): + return reflect(cls, *args) + result = eager.dispatch(cls, *args)(*args) + if result is None: + result = normalize.dispatch(cls, *args)(*args) if result is None: result = lazy.dispatch(cls, *args)(*args) if result is None: @@ -558,7 +562,7 @@ def test_generic_distribution_to_funsor(case): expected_value_domain = case.expected_value_domain dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) - with interpretation(normalize_with_subs): + with interpretation(eager_no_dists): funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) assert funsor_dist.inputs["value"] == expected_value_domain @@ -592,7 +596,7 @@ def test_generic_log_prob(case, use_lazy): expected_value_domain = case.expected_value_domain dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) - with interpretation(normalize_with_subs if use_lazy else eager): + with interpretation(eager_no_dists if use_lazy else eager): # some distributions have nontrivial eager patterns funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) expected_inputs = {name: funsor.Bint[raw_dist.batch_shape[dim]] for dim, name in dim_to_name.items()} @@ -615,7 +619,7 @@ def test_generic_enumerate_support(case, expand): raw_dist = case.get_dist() dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) - with interpretation(normalize_with_subs): + with interpretation(eager_no_dists): funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) assert getattr(raw_dist, "has_enumerate_support", False) == getattr(funsor_dist, "has_enumerate_support", False) @@ -633,7 +637,7 @@ def test_generic_sample(case, sample_shape): raw_dist = case.get_dist() dim_to_name, name_to_dim = _default_dim_to_name(sample_shape + raw_dist.batch_shape) - with interpretation(normalize_with_subs): + with interpretation(eager_no_dists): funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) sample_inputs = OrderedDict((dim_to_name[dim - len(raw_dist.batch_shape)], funsor.Bint[sample_shape[dim]]) @@ -655,7 +659,7 @@ def test_generic_stats(case, statistic): raw_dist = case.get_dist() dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) - with interpretation(normalize_with_subs): + with interpretation(eager_no_dists): funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) with xfail_if_not_implemented(msg="entropy not implemented for some distributions"), \