Switch to sqrt(precision) representation in Gaussian? #567

Closed
3 tasks done
fritzo opened this issue Oct 6, 2021 · 0 comments · Fixed by #568

fritzo commented Oct 6, 2021

Addresses #559

This issue proposes switching the Gaussian parameters from (info_vec, precision) to (info_vec, prec_sqrt), following @fehiepsi's work in pyro-ppl/pyro#2019.

Motivation

Our original motivation for representing Gaussians as (info_vec, precision) was to support operations on rank-deficient precision matrices, which occur in conditional probability distributions. This design choice allows us to uniformly handle priors, conditional distributions, and likelihoods by treating all three agnostically as mere factors in a factor graph.

However, while the (info_vec, precision) representation is numerically stable, it is computationally inefficient when making low-dimensional observations of a high-dimensional variable. For example, to store a conditional distribution of a 1-dimensional variable given a 1000-dimensional variable, the precision matrix has 1001**2 elements, but since it has rank 1, its square root would cost only 1001 elements. Indeed, we recognized this in pyro-ppl/pyro#2005 and #217 and created a special AffineNormal pattern to avoid materializing rank-1 precision matrices.
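To make the rank-1 arithmetic concrete, here is a minimal NumPy sketch (not funsor's API). The helper `affine_normal_factor` is hypothetical: it builds the information vector and a square root S (with precision = S @ S.T) for a scalar observation y ~ Normal(a @ x, sigma) of a 1000-dimensional x. It assumes zero bias, variables stacked as z = (x, y), and the convention log p(z) = info_vec @ z - 0.5 * z @ precision @ z + const.

```python
import numpy as np


def affine_normal_factor(a, sigma):
    """Square-root information pair for p(y | x) = Normal(y; a @ x, sigma).

    Variables are stacked as z = (x, y). Up to a constant, the log density is
    -0.5 * ((y - a @ x) / sigma) ** 2 = -0.5 * (v @ z) ** 2 with v = [-a, 1] / sigma,
    so precision = outer(v, v) has rank 1 and info_vec = 0.
    """
    v = np.concatenate([-a, [1.0]]) / sigma
    info_vec = np.zeros_like(v)
    prec_sqrt = v[:, None]  # shape (dim, rank) = (len(a) + 1, 1)
    return info_vec, prec_sqrt


rng = np.random.default_rng(0)
a = rng.normal(size=1000)  # a 1-dim y observed through a 1000-dim x
info_vec, prec_sqrt = affine_normal_factor(a, sigma=0.5)
precision = prec_sqrt @ prec_sqrt.T  # dense form, materialized only for comparison
print(prec_sqrt.size, precision.size)  # 1001 1002001
```

The dense precision carries about a thousand times more numbers than its single-column square root, even though both describe the same factor.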

An alternative representation is the classic square root information filter (SRIF), explored by @fehiepsi in pyro-ppl/pyro#2019. This represents a Gaussian as a pair (info_vec, prec_sqrt), of shapes batch_shape + (dim,) and batch_shape + (dim, rank) respectively, so that

precision = prec_sqrt @ prec_sqrt.transpose(-2, -1)
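A minimal NumPy sketch of evaluating a Gaussian directly from (info_vec, prec_sqrt) with the batch shapes above. It assumes the unnormalized log density info_vec @ z - 0.5 * z @ precision @ z, and the helper names are hypothetical, not part of funsor. Because z @ precision @ z equals the squared norm of prec_sqrt^T z, the quadratic term only needs a rank-sized projection, costing O(dim * rank) per batch element instead of O(dim ** 2).

```python
import numpy as np


def log_density_sqrt(info_vec, prec_sqrt, z):
    """Unnormalized log density info_vec @ z - 0.5 * z @ precision @ z.

    Shapes: info_vec and z are batch_shape + (dim,), prec_sqrt is
    batch_shape + (dim, rank). Uses z @ precision @ z = |prec_sqrt^T z|^2,
    so the (dim, dim) precision is never formed.
    """
    proj = np.einsum("...dr,...d->...r", prec_sqrt, z)  # batch_shape + (rank,)
    linear = np.einsum("...d,...d->...", info_vec, z)
    return linear - 0.5 * np.einsum("...r,...r->...", proj, proj)


def log_density_dense(info_vec, precision, z):
    """The same quantity computed from the dense precision matrix, for comparison."""
    quad = np.einsum("...i,...ij,...j->...", z, precision, z)
    return np.einsum("...d,...d->...", info_vec, z) - 0.5 * quad


rng = np.random.default_rng(0)
batch_shape, dim, rank = (3,), 5, 2
info_vec = rng.normal(size=batch_shape + (dim,))
prec_sqrt = rng.normal(size=batch_shape + (dim, rank))
precision = prec_sqrt @ np.swapaxes(prec_sqrt, -2, -1)  # mirrors prec_sqrt.transpose(-2, -1)
z = rng.normal(size=batch_shape + (dim,))
print(np.allclose(log_density_sqrt(info_vec, prec_sqrt, z),
                  log_density_dense(info_vec, precision, z)))  # True
```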

Advantages of the square root information representation include:

  • storage cost is O(dim * rank), which can be much smaller than O(dim ** 2) when rank << dim
  • ops.add can be implemented by mere concatenation of the prec_sqrt matrices along the rank dimension, plus summing the info_vecs (optionally followed by compression if rank > dim or maybe rank > 1.5 * dim)
  • we could drop AffineNormal
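A minimal NumPy sketch of the concatenation-based ops.add, assuming both factors are over the same variables in the same order. The helper names and the `ratio` threshold are hypothetical: a `ratio` of 1.0 corresponds to the rank > dim rule and 1.5 to the alternative mentioned above. Compression here uses a reduced QR factorization of prec_sqrt.T, one standard way to recover a square root with at most dim columns; funsor might well choose a different method.

```python
import numpy as np


def add_sqrt(lhs, rhs):
    """Sum of two Gaussian factors over the same dim variables.

    Info vectors add, and P1 + P2 = [S1, S2] @ [S1, S2]^T, so the square roots
    are concatenated along the last (rank) axis.
    """
    (info1, sqrt1), (info2, sqrt2) = lhs, rhs
    return info1 + info2, np.concatenate([sqrt1, sqrt2], axis=-1)


def compress(prec_sqrt):
    """Return an equivalent square root with at most dim columns (2-D input only).

    If prec_sqrt.T = Q @ R is a reduced QR factorization, then
    prec_sqrt @ prec_sqrt.T = R.T @ Q.T @ Q @ R = R.T @ R, so R.T works.
    """
    _, r = np.linalg.qr(prec_sqrt.T)  # r has shape (min(rank, dim), dim)
    return r.T


def add_and_maybe_compress(lhs, rhs, ratio=1.0):
    """Concatenate, then compress only when rank exceeds ratio * dim."""
    info_vec, prec_sqrt = add_sqrt(lhs, rhs)
    dim, rank = prec_sqrt.shape
    if rank > ratio * dim:
        prec_sqrt = compress(prec_sqrt)
    return info_vec, prec_sqrt


rng = np.random.default_rng(0)
dim = 4
factors = [(rng.normal(size=dim), rng.normal(size=(dim, 3))) for _ in range(3)]
total = factors[0]
for f in factors[1:]:
    total = add_and_maybe_compress(total, f)
expected = sum(s @ s.T for _, s in factors)
print(total[1].shape, np.allclose(total[1] @ total[1].T, expected))  # (4, 4) True
```

Deferring compression behind a threshold keeps most additions as cheap concatenations and pays the O(dim ** 2 * rank) QR cost only occasionally.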

Design questions

  • Should we make this change?
  • If we change the Gaussian interface, what will break?
    • the arXiv paper will be out of date
    • I believe all Pyro & NumPyro code uses to_funsor(), so it will be unaffected
  • How should we make this change? Possibilities include:
    • 👍 Simply replace the existing Gaussian.
    • Build a separate GaussianS.
    • Create a new ConditionalGaussian as the basic type and make Gaussian a mere alias
      (but how do we unify priors, conditionals, and likelihoods?)
    • Create an abstraction for structured precision matrices as in GPyTorch.
      (or is it cleaner to allow structured square root matrices, since they have natural block structure?)