
Automate distribution testing #389

Merged
merged 40 commits into master from distribution-test-harness on Nov 11, 2020

Conversation


@eb8680 eb8680 commented Oct 30, 2020

Addresses #386. Blocked by #388.

This PR adds a new file, test_distribution_generic.py, that refactors the distribution tests. There is one generic test for each distribution method, and all that is needed to test a new distribution is to add a recipe for constructing a random instance to a list of test cases. This change is necessary if we want to approach full coverage of the dozens of distributions in PyTorch, Pyro/NumPyro, and TFP without thousands of lines of manually duplicated testing logic.
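To make the recipe-list pattern concrete, here is a minimal sketch (not the PR's actual code; the case.build() helper and the smoke-test body are assumptions) of how a single generic test is parametrized over a shared list of cases:

import pytest

TEST_CASES = []  # populated with one DistTestCase recipe per distribution

@pytest.mark.parametrize("case", TEST_CASES)
def test_generic_sample(case):
    # Smoke test: building the distribution from its recipe and drawing a sample should not raise.
    funsor_dist = case.build()  # hypothetical helper that evaluates the recipe
    sample = funsor_dist.sample(frozenset(["value"]))
    assert sample is not None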

I have also had to add a number of small fixes to get the new tests to pass, which I would argue is a sign of the value of these tests - even the first version here is exercising many small distribution API edge cases that would be hard to catch with the current approach. There are also a couple of new testing utilities (ops.isnan and testing.random_scale_tril).

If the approach in this PR works, I will delete many of the one-off tests in test_distribution.py in a followup PR.

Remaining tasks:

  • Get generic gradient test running
  • Add enough test cases to reach coverage parity with the tests in test_distribution.py
  • Get all deterministic tests to pass and avoid nans

Triaged

  • Get all Monte Carlo tests to pass - I am skeptical of the viability and usefulness of doing this and have simply disabled and removed these tests for now.

@eb8680 eb8680 added the Blocked, WIP, refactor, and testing labels on Oct 30, 2020
@eb8680 eb8680 changed the base branch from master to distribution-stats October 30, 2020 17:28
Base automatically changed from distribution-stats to master October 30, 2020 23:47
@eb8680 eb8680 mentioned this pull request Oct 31, 2020
@eb8680 eb8680 removed the Blocked label on Oct 31, 2020
@eb8680 eb8680 added the awaiting review label and removed the WIP label on Nov 4, 2020
@eb8680 eb8680 requested a review from fritzo November 4, 2020 00:19
"variance",
pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")])
])
def test_generic_stats_sample(case, statistic, sample_shape):
Member Author

This test comparing Monte Carlo estimates of summary statistics with ground-truth values is slow (especially on the JAX backend) and finicky. I've disabled it by default but could also remove it entirely - I'm not sure failures are very informative.
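For context, the kind of check being described is roughly the following (a simplified sketch using plain torch.distributions, not the PR's code):

import torch
from torch.distributions import Beta

# Compare Monte Carlo estimates of summary statistics against their ground-truth values.
dist = Beta(torch.tensor(2.0), torch.tensor(3.0))
samples = dist.sample((100000,))
assert torch.allclose(samples.mean(), dist.mean, atol=1e-2)     # MC mean vs. exact mean
assert torch.allclose(samples.var(), dist.variance, atol=1e-2)  # MC variance vs. exact variance

The tolerance makes this kind of check inherently flaky, which is the finickiness referred to above.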

Member Author

I removed these tests from this PR.

"variance",
pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")])
])
def test_generic_grads_sample(case, statistic, sample_shape):
Member Author

Ditto for this test comparing Monte Carlo estimates of gradients of summary statistics wrt parameters versus gradients of ground-truth values - it's slow and finicky, and I'm not sure failures are very informative. It's disabled by default, but I am open to removing it entirely.
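A correspondingly simplified sketch of the gradient version (again plain torch.distributions rather than the PR's code, differentiating only through loc):

import torch
from torch.distributions import Normal

loc = torch.randn(3, requires_grad=True)
dist = Normal(loc, torch.ones(3))

# Gradient of a reparameterized Monte Carlo estimate of the mean wrt the parameter...
mc_mean = dist.rsample((10000,)).mean(0)
mc_grad = torch.autograd.grad(mc_mean.sum(), loc)[0]

# ...versus the gradient of the ground-truth mean wrt the parameter.
exact_grad = torch.autograd.grad(dist.mean.sum(), loc)[0]
assert torch.allclose(mc_grad, exact_grad, atol=1e-1)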

Comment on lines 68 to 72
TEST_CASES += [DistTestCase(
    "backend_dist.Bernoulli(logits=case.logits)",
    (("logits", f"rand({batch_shape})"),),
    funsor.Real,
)]
Member Author

After this PR, adding a new distribution to funsor.distributions will be as simple as adding it to the list of distributions to be wrapped in funsor/{backend}/distributions.py and adding a new test case to this list.

Member

adding it to the list

Is there a .register() command allowing users to dynamically extend the list in their own code? Similar to the dynamic registration mechanisms in kl_divergence.register() or biject_to.register? Even better, could we automatically register distributions in Funsor the first time they are encountered?

Member Author
@eb8680 eb8680 Nov 4, 2020

Is there a .register() command allowing users to dynamically extend the list in their own code?

Yes, funsor.distribution.make_dist basically plays this role, especially after #391. It takes a backend distribution class and (optionally) some parameter names as input, and generates a new funsor.distribution.Distribution with generic eager_subs, unscaled_sample, to_funsor, and to_data patterns that should work for most use cases, provided the user has correctly implemented .arg_constraints and .support in their custom distribution.
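As a rough illustration of that path (a sketch only: the custom distribution below is hypothetical and the param_names argument is assumed from the description above):

import torch
from torch.distributions import constraints
import funsor
import funsor.distribution

funsor.set_backend("torch")

class MyNormal(torch.distributions.Normal):
    # A hypothetical user-defined backend distribution with arg_constraints and support.
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real

# Wrap it as a funsor Distribution with the generic eager_subs/unscaled_sample/to_funsor/to_data patterns.
MyNormalFunsor = funsor.distribution.make_dist(MyNormal, param_names=("loc", "scale"))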


# BernoulliLogits
TEST_CASES += [DistTestCase(
    "backend_dist.Bernoulli(logits=case.logits)",
Member Author

Note I've chosen to encode test cases with strings of Python code. This should allow us to apply these generic tests even to quite complicated backend distribution expressions, e.g. TransformedDistribution(Normal(loc, scale).to_event(1), [TanhTransform,]).mask(mask)
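For instance, a more involved case could hypothetically be encoded along these lines (illustrative only, following the DistTestCase pattern used in this file; the exact expression, shapes, and transform are made up):

TEST_CASES += [DistTestCase(
    "backend_dist.TransformedDistribution("
    "backend_dist.Normal(case.loc, case.scale).to_event(1), "
    "[backend_dist.transforms.TanhTransform()])",
    (("loc", f"rand({batch_shape} + (3,))"), ("scale", f"rand({batch_shape} + (3,))")),
    funsor.Reals[3],
)]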


result = funsor.delta.Delta(value.name, Tensor(raw_sample, inputs, value.output.dtype))
funsor_value = to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
Member Author

I had to update funsor.distribution.Distribution.unscaled_sample to use to_funsor and to_data throughout to get some tests to pass.


eb8680 commented Nov 4, 2020

@fehiepsi Travis seems to be hanging somewhere in the JAX tests, but I haven't been able to reproduce it locally and I can't tell from the logs where it's hanging, although I assume it's related to the distribution tests. Any idea what might be going on?


eb8680 commented Nov 4, 2020

One thing I'm noticing is that test_dirichlet_sample and the other sampler tests in test/test_distribution.py are using huge amounts of memory under the JAX backend, both in absolute terms and relative to the PyTorch backend, and may be leaking memory. This was true before this PR, but maybe it got worse? (edit: no, it seems about the same)

The new sampler tests in this PR (test_generic_sample in test_distribution_generic.py) are basically smoke tests that should in theory be much less expensive, but I'm still seeing a huge gap in runtime and memory usage between the PyTorch and JAX backends. Maybe they're the tipping point...


@fritzo fritzo left a comment


...still reviewing test_distribution_generic.py. I think it will be easier to read after adding TEST_CASES.append(self) to .__init__().
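Something like the following, presumably (attribute names are guesses; the point is only the self-registration in __init__):

import funsor

TEST_CASES = []
batch_shape = (5,)  # placeholder for this sketch

class DistTestCase:
    def __init__(self, raw_dist, raw_params, expected_value_domain):
        self.raw_dist = raw_dist
        self.raw_params = raw_params
        self.expected_value_domain = expected_value_domain
        TEST_CASES.append(self)  # each case registers itself at construction time

# Cases can then be declared without the TEST_CASES += [...] boilerplate:
DistTestCase(
    "backend_dist.Bernoulli(logits=case.logits)",
    (("logits", f"rand({batch_shape})"),),
    funsor.Real,
)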

def multinomial_to_data(funsor_dist, name_to_dim=None):
    probs = to_data(funsor_dist.probs, name_to_dim)
    total_count = to_data(funsor_dist.total_count, name_to_dim)
    if isinstance(total_count, numbers.Number) or len(total_count.shape) == 0:
Member

Do we need to worry about int(total_count) thwarting PyTorch tracing? Should we preserve scalar Tensors?

Member Author

I made this change because torch.distributions.Multinomial raises an error when given a tensor total_count, even if it is a scalar tensor, and the generic conversion tests were failing. It's definitely an issue, but there are already lots of JIT issues with scalars and tuple shapes throughout the codebase (encountered while working on pyro.contrib.funsor). My inclination with this PR and further work on distributions is to avoid special-casing to the greatest extent possible, even if that means deferring to odd implementation details in the backends. I think a better fix in this instance would be to allow a scalar tensor total_count upstream, so this change could be reverted.

This is also a good reminder to add some generic JIT tests for distribution wrappers in a followup PR.
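For reference, the conversion under discussion is roughly the following (a sketch; the helper name is made up and the surrounding to_data conversion code is elided):

import numbers

def _int_total_count(total_count):
    # torch.distributions.Multinomial rejects a tensor-valued total_count, so
    # scalar tensors are converted to Python ints here; under torch.jit tracing
    # this int() call bakes the value in as a constant.
    if isinstance(total_count, numbers.Number) or len(total_count.shape) == 0:
        return int(total_count)
    return total_count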


fehiepsi commented Nov 4, 2020

@eb8680 I think the memory issue is due to the "caching" mechanism of JAX: it caches compiled code so that the next time a function is called with inputs of the same shape, it executes quickly. I guess we can split

pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi

into smaller runs, e.g.

pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi --ignore=test/test_distribution.py
pytest -v -n auto test/test_distribution.py


eb8680 commented Nov 4, 2020

Would that also explain the runtime differences? For comparison, the new sampler smoke tests in this PR were taking ~1s on a single core on my laptop for PyTorch vs. ~100s for JAX.


fehiepsi commented Nov 4, 2020

How many tests are there? If there are 100 tests, that seems normal to me.


eb8680 commented Nov 4, 2020

How many tests are there? If there are 100 tests, that seems normal to me.

Around that, yeah.

@fritzo fritzo merged commit 54e28ab into master Nov 11, 2020
@fritzo fritzo deleted the distribution-test-harness branch November 11, 2020 17:58