Add generic stat methods to Distribution #388

Merged · 1 commit · Oct 30, 2020
42 changes: 34 additions & 8 deletions funsor/distribution.py
@@ -105,6 +105,22 @@ def eager_log_prob(cls, *params):
         data = cls.dist_class(**params).log_prob(value)
         return Tensor(data, inputs)
 
+    def _get_raw_dist(self):
+        """
+        Internal method for working with underlying distribution attributes
+        """
+        if isinstance(self.value, Variable):
+            value_name = self.value.name
+        else:
+            raise NotImplementedError("cannot get raw dist for {}".format(self))
+        # arbitrary name-dim mapping, since we're converting back to a funsor anyway
+        name_to_dim = {name: -dim-1 for dim, (name, domain) in enumerate(self.inputs.items())
+                       if isinstance(domain.dtype, int) and name != value_name}
+        raw_dist = to_data(self, name_to_dim=name_to_dim)
+        dim_to_name = {dim: name for name, dim in name_to_dim.items()}
+        # also return value output, dim_to_name for converting results back to funsor
+        return raw_dist, self.value.output, dim_to_name
+
     @property
     def has_rsample(self):
         return getattr(self.dist_class, "has_rsample", False)
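For context, `_get_raw_dist` leans on funsor's `to_data`/`to_funsor` round trip: each batch input is assigned an arbitrary negative dim on the way down to the backend distribution, and the inverse mapping lifts the backend result back into a funsor. A simplified standalone sketch of that mapping (the input names `i` and `j` are hypothetical, and the integer-dtype filter is omitted):

```python
# Hypothetical batch inputs "i" and "j" plus the distribution's "value" input.
input_names = ["i", "j", "value"]
value_name = "value"

# Assign each non-value input an arbitrary negative dim, mirroring _get_raw_dist.
name_to_dim = {name: -dim - 1 for dim, name in enumerate(input_names) if name != value_name}
# Invert the mapping so backend results can be converted back to named funsors.
dim_to_name = {dim: name for name, dim in name_to_dim.items()}

print(name_to_dim)  # {'i': -1, 'j': -2}
print(dim_to_name)  # {-1: 'i', -2: 'j'}
```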
@@ -139,16 +155,26 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
         return result
 
     def enumerate_support(self, expand=False):
-        if not self.has_enumerate_support or not isinstance(self.value, Variable):
-            raise ValueError("cannot enumerate support of {}".format(repr(self)))
-        # arbitrary name-dim mapping, since we're converting back to a funsor anyway
-        name_to_dim = {name: -dim-1 for dim, (name, domain) in enumerate(self.inputs.items())
-                       if isinstance(domain.dtype, int) and name != self.value.name}
-        raw_dist = to_data(self, name_to_dim=name_to_dim)
+        assert self.has_enumerate_support and isinstance(self.value, Variable)
+        raw_dist, value_output, dim_to_name = self._get_raw_dist()
         raw_value = raw_dist.enumerate_support(expand=expand)
-        dim_to_name = {dim: name for name, dim in name_to_dim.items()}
-        dim_to_name[min(dim_to_name.keys(), default=0)-1] = self.value.name
-        return to_funsor(raw_value, output=self.value.output, dim_to_name=dim_to_name)
+        return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
+
+    def entropy(self):
+        raw_dist, value_output, dim_to_name = self._get_raw_dist()
+        raw_value = raw_dist.entropy()
+        return to_funsor(raw_value, output=self.output, dim_to_name=dim_to_name)
+
+    def mean(self):
+        raw_dist, value_output, dim_to_name = self._get_raw_dist()
+        raw_value = raw_dist.mean
+        return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
+
+    def variance(self):
+        raw_dist, value_output, dim_to_name = self._get_raw_dist()
+        raw_value = raw_dist.variance
+        return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
 
     def __getattribute__(self, attr):
         if attr in type(self)._ast_fields and attr != 'name':
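Taken together, these additions let callers ask a funsor distribution for its analytic statistics directly, with results converted back to funsors carrying the distribution's batch inputs. A minimal usage sketch, assuming the torch backend and the `Normal` wrapper in `funsor.torch.distributions` (the batch input name `i` is made up for illustration):

```python
from collections import OrderedDict

import torch

import funsor
from funsor.domains import Bint
from funsor.tensor import Tensor

funsor.set_backend("torch")
from funsor.torch.distributions import Normal  # noqa: E402

# A batch of three Normals over a free "value" variable.
loc = Tensor(torch.randn(3), OrderedDict(i=Bint[3]))
scale = Tensor(torch.randn(3).exp(), OrderedDict(i=Bint[3]))
d = Normal(loc, scale)

print(d.mean())      # Tensor with batch input "i", equal to loc
print(d.variance())  # elementwise scale ** 2
print(d.entropy())   # analytic entropy, one value per batch element
```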
23 changes: 10 additions & 13 deletions test/test_distribution.py
@@ -17,7 +17,7 @@
 from funsor.domains import Bint, Real, Reals
 from funsor.integrate import Integrate
 from funsor.interpreter import interpretation, reinterpret
-from funsor.tensor import Einsum, Tensor, align_tensors, numeric_array, stack
+from funsor.tensor import Einsum, Tensor, numeric_array, stack
 from funsor.terms import Independent, Variable, eager, lazy, to_funsor
 from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_mvn, random_tensor, xfail_param
 from funsor.util import get_backend
@@ -701,29 +701,26 @@ def _get_stat_diff(funsor_dist_class, sample_inputs, inputs, num_samples, statistic
     check_funsor(sample_value, expected_inputs, Real)
 
     if sample_inputs:
-
-        actual_mean = Integrate(
-            sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value'])
-        ).reduce(ops.add, frozenset(sample_inputs))
-
-        inputs, tensors = align_tensors(*list(funsor_dist.params.values())[:-1])
-        raw_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], tensors)))
-        expected_mean = Tensor(raw_dist.mean, inputs)
-
         if statistic == "mean":
-            actual_stat, expected_stat = actual_mean, expected_mean
+            actual_stat = Integrate(
+                sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value'])
+            ).reduce(ops.add, frozenset(sample_inputs))
+            expected_stat = funsor_dist.mean()
         elif statistic == "variance":
+            actual_mean = Integrate(
+                sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value'])
+            ).reduce(ops.add, frozenset(sample_inputs))
             actual_stat = Integrate(
                 sample_value,
                 (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2,
                 frozenset(['value'])
             ).reduce(ops.add, frozenset(sample_inputs))
-            expected_stat = Tensor(raw_dist.variance, inputs)
+            expected_stat = funsor_dist.variance()
         elif statistic == "entropy":
             actual_stat = -Integrate(
                 sample_value, funsor_dist, frozenset(['value'])
             ).reduce(ops.add, frozenset(sample_inputs))
-            expected_stat = Tensor(raw_dist.entropy(), inputs)
+            expected_stat = funsor_dist.entropy()
         else:
             raise ValueError("invalid test statistic")
 
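The reworked test now compares a Monte Carlo estimate, computed with `Integrate` over sampled values, against the analytic methods added above. A torch-only sketch of the same sanity check, independent of funsor, just to illustrate the property being tested:

```python
import torch

# Monte Carlo estimates of the mean, variance, and entropy should approach
# the analytic values exposed by the distribution, up to sampling error.
d = torch.distributions.Normal(torch.tensor(1.5), torch.tensor(0.7))
x = d.sample((100_000,))

mc_mean = x.mean()
mc_var = ((x - mc_mean) ** 2).mean()
mc_entropy = -d.log_prob(x).mean()

print(mc_mean.item(), d.mean.item())
print(mc_var.item(), d.variance.item())
print(mc_entropy.item(), d.entropy().item())
```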