diff --git a/setup.py b/setup.py
index bc89799..0c00078 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
 
 setuptools.setup(
     name='zuko',
-    version='0.0.7',
+    version='0.0.8',
     packages=setuptools.find_packages(),
     description='Normalizing flows in PyTorch',
     keywords=[
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index e2a3cc5..54476af 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -3,12 +3,13 @@
 import pytest
 import torch
 
+from torch.distributions import *
 from zuko.distributions import *
 
 
 def test_distributions():
     ds = [
-        NormalizingFlow([ExpTransform()], Gamma(2.0, 1.0)),
+        NormalizingFlow(ExpTransform(), Gamma(2.0, 1.0)),
         Joint(Uniform(0.0, 1.0), Normal(0.0, 1.0)),
         GeneralizedNormal(2.0),
         DiagNormal(torch.zeros(2), torch.ones(2)),
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 5d89af0..70da322 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -4,6 +4,7 @@
 import torch
 
 from torch import randn
+from torch.distributions import *
 from zuko.transforms import *
 
 
diff --git a/zuko/distributions.py b/zuko/distributions.py
index 0a977c3..ace2426 100644
--- a/zuko/distributions.py
+++ b/zuko/distributions.py
@@ -1,5 +1,19 @@
 r"""Parameterizable probability distributions."""
 
+__all__ = [
+    'NormalizingFlow',
+    'Joint',
+    'GeneralizedNormal',
+    'DiagNormal',
+    'BoxUniform',
+    'TransformedUniform',
+    'Truncated',
+    'Sort',
+    'TopK',
+    'Minimum',
+    'Maximum',
+]
+
 import math
 import torch
 
@@ -17,14 +31,12 @@
 class NormalizingFlow(Distribution):
     r"""Creates a normalizing flow for a random variable :math:`X` towards a base
-    distribution :math:`p(Z)` through a series of :math:`n` invertible and differentiable
-    transformations :math:`f_1, f_2, \dots, f_n`.
+    distribution :math:`p(Z)` through a transformation :math:`f`.
 
     The density of a realization :math:`x` is given by the change of variables
 
-    .. math:: p(X = x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right|
+    .. math:: p(X = x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right| .
 
-    where :math:`f = f_1 \circ \dots \circ f_n` is the transformations' composition.
     To sample from :math:`p(X)`, realizations :math:`z \sim p(Z)` are mapped through
     the inverse transformation :math:`g = f^{-1}`.
@@ -36,34 +48,38 @@
         | https://arxiv.org/abs/1912.02762
 
     Arguments:
-        transforms: A list of transformations :math:`f_i`.
+        transform: A transformation :math:`f`.
         base: A base distribution :math:`p(Z)`.
 
     Example:
-        >>> d = NormalizingFlow([ExpTransform()], Gamma(2.0, 1.0))
+        >>> d = NormalizingFlow(ExpTransform(), Gamma(2.0, 1.0))
         >>> d.sample()
         tensor(1.1316)
     """
 
+    has_rsample = True
+
     def __init__(
         self,
-        transforms: List[Transform],
+        transform: Transform,
         base: Distribution,
     ):
         super().__init__()
 
-        codomain_dim = ComposeTransform(transforms).codomain.event_dim
-        reinterpreted = codomain_dim - len(base.event_shape)
+        reinterpreted = transform.codomain.event_dim - len(base.event_shape)
 
         if reinterpreted > 0:
             base = Independent(base, reinterpreted)
 
-        self.transforms = transforms
+        self.transform = transform
         self.base = base
+        self.reinterpreted = max(-reinterpreted, 0)
 
     def __repr__(self) -> str:
-        lines = [f'({i + 1}): {t}' for i, t in enumerate(self.transforms)]
-        lines.append(f'(base): {self.base}')
+        lines = [
+            f'(transform): {self.transform}',
+            f'(base): {self.base}',
+        ]
         lines = indent('\n'.join(lines), '  ')
 
         return self.__class__.__name__ + '(\n' + lines + '\n)'
@@ -74,53 +90,31 @@ def batch_shape(self) -> Size:
 
     @property
     def event_shape(self) -> Size:
-        shape = self.base.event_shape
-
-        for t in reversed(self.transforms):
-            shape = t.inverse_shape(shape)
-
-        return shape
+        return self.transform.inverse_shape(self.base.event_shape)
 
     def expand(self, batch_shape: Size, new: Distribution = None):
         new = self._get_checked_instance(NormalizingFlow, new)
-        new.transforms = self.transforms
+        new.transform = self.transform
        new.base = self.base.expand(batch_shape)
+        new.reinterpreted = self.reinterpreted
 
-        Distribution.__init__(new, batch_shape=batch_shape, validate_args=False)
+        Distribution.__init__(new, validate_args=False)
 
         return new
 
     def log_prob(self, x: Tensor) -> Tensor:
-        acc = 0
-        event_dim = len(self.event_shape)
-
-        for t in self.transforms:
-            x, ladj = t.call_and_ladj(x)
-            acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
-            event_dim += t.codomain.event_dim - t.domain.event_dim
+        z, ladj = self.transform.call_and_ladj(x)
+        ladj = _sum_rightmost(ladj, self.reinterpreted)
 
-        return self.base.log_prob(x) + acc
-
-    @property
-    def has_rsample(self) -> bool:
-        return self.base.has_rsample
+        return self.base.log_prob(z) + ladj
 
     def rsample(self, shape: Size = ()):
-        x = self.base.rsample(shape)
-
-        for t in reversed(self.transforms):
-            x = t.inv(x)
-
-        return x
-
-    def sample(self, shape: Size = ()):
-        with torch.no_grad():
-            x = self.base.sample(shape)
-
-            for t in reversed(self.transforms):
-                x = t.inv(x)
+        if self.base.has_rsample:
+            z = self.base.rsample(shape)
+        else:
+            z = self.base.sample(shape)
 
-            return x
+        return self.transform.inv(z)
 
 
 class Joint(Distribution):
@@ -334,7 +328,7 @@ class TransformedUniform(NormalizingFlow):
     """
 
     def __init__(self, f: Transform, lower: Tensor, upper: Tensor):
-        super().__init__([f], Uniform(*map(f, map(torch.as_tensor, (lower, upper)))))
+        super().__init__(f, Uniform(*map(f, map(torch.as_tensor, (lower, upper)))))
 
     def expand(self, batch_shape: Size, new: Distribution = None) -> Distribution:
         new = self._get_checked_instance(TransformedUniform, new)
@@ -372,7 +366,7 @@ def __init__(
     ):
         super().__init__(batch_shape=base.batch_shape)
 
-        assert len(base.event_shape) < 1, "'base' has to be univariate"
+        assert not base.event_shape, "'base' has to be univariate"
 
         self.base = base
         self.uniform = Uniform(base.cdf(lower), base.cdf(upper))
@@ -430,7 +424,7 @@ def __init__(
     ):
         super().__init__(batch_shape=base.batch_shape)
 
-        assert len(base.event_shape) < 1, "'base' has to be univariate"
+        assert not base.event_shape, "'base' has to be univariate"
 
         self.base = base
         self.n = n
diff --git a/zuko/flows.py b/zuko/flows.py
index f208117..0e897e3 100644
--- a/zuko/flows.py
+++ b/zuko/flows.py
@@ -22,6 +22,7 @@
 from functools import partial
 from math import ceil
 from torch import Tensor, LongTensor, Size
+from torch.distributions import *
 from typing import *
 
 from .distributions import *
@@ -89,14 +90,14 @@ def forward(self, y: Tensor = None) -> NormalizingFlow:
             A normalizing flow :math:`p(X | y)`.
         """
 
-        transforms = [t(y) for t in self.transforms]
+        transform = ComposedTransform(*(t(y) for t in self.transforms))
 
         if y is None:
             base = self.base(y)
         else:
             base = self.base(y).expand(y.shape[:-1])
 
-        return NormalizingFlow(transforms, base)
+        return NormalizingFlow(transform, base)
 
 
 class Unconditional(nn.Module):
diff --git a/zuko/nn.py b/zuko/nn.py
index 1c3667e..99a0d81 100644
--- a/zuko/nn.py
+++ b/zuko/nn.py
@@ -39,7 +39,7 @@ def forward(self, x: Tensor) -> Tensor:
 class MLP(nn.Sequential):
     r"""Creates a multi-layer perceptron (MLP).
 
-    Also known as fully connected feedforward network, an MLP is a series of
+    Also known as fully connected feedforward network, an MLP is a sequence of
     non-linear parametric functions
 
     .. math:: h_{i + 1} = a_{i + 1}(h_i W_{i + 1}^T + b_{i + 1}),
diff --git a/zuko/transforms.py b/zuko/transforms.py
index 730299b..d9796f4 100644
--- a/zuko/transforms.py
+++ b/zuko/transforms.py
@@ -1,12 +1,30 @@
 r"""Parameterizable transformations."""
 
+__all__ = [
+    'ComposedTransform',
+    'IdentityTransform',
+    'CosTransform',
+    'SinTransform',
+    'SoftclipTransform',
+    'MonotonicAffineTransform',
+    'MonotonicRQSTransform',
+    'MonotonicTransform',
+    'UnconstrainedMonotonicTransform',
+    'SOSPolynomialTransform',
+    'FFJTransform',
+    'AutoregressiveTransform',
+    'PermutationTransform',
+]
+
 import math
 import torch
 import torch.nn.functional as F
 
-from torch import Tensor, LongTensor
+from textwrap import indent
+from torch import Tensor, LongTensor, Size
 from torch.distributions import *
 from torch.distributions import constraints
+from torch.distributions.utils import _sum_rightmost
 from typing import *
 
 from .utils import bisection, broadcast, gauss_legendre, odeint
@@ -28,6 +46,97 @@ def _call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
 
 Transform.call_and_ladj = _call_and_ladj
 
 
+class ComposedTransform(Transform):
+    r"""Creates a transformation :math:`f(x) = f_n \circ \dots \circ f_0(x)`.
+
+    Arguments:
+        transforms: A sequence of transformations :math:`f_i`.
+    """
+
+    def __init__(self, *transforms: Transform, **kwargs):
+        super().__init__(**kwargs)
+
+        assert transforms, "'transforms' cannot be empty"
+
+        event_dim = 0
+
+        for t in reversed(transforms):
+            event_dim = t.domain.event_dim + max(event_dim - t.codomain.event_dim, 0)
+
+        self.domain_dim = event_dim
+
+        for t in transforms:
+            event_dim += t.codomain.event_dim - t.domain.event_dim
+
+        self.codomain_dim = event_dim
+        self.transforms = transforms
+
+    def __repr__(self) -> str:
+        lines = [f'({i}): {t}' for i, t in enumerate(self.transforms)]
+        lines = indent('\n'.join(lines), '  ')
+
+        return f'{self.__class__.__name__}(\n' + lines + '\n)'
+
+    @property
+    def domain(self) -> constraints.Constraint:
+        domain = self.transforms[0].domain
+        reinterpreted = self.domain_dim - domain.event_dim
+
+        if reinterpreted > 0:
+            return constraints.independent(domain, reinterpreted)
+        else:
+            return domain
+
+    @property
+    def codomain(self) -> constraints.Constraint:
+        codomain = self.transforms[-1].codomain
+        reinterpreted = self.codomain_dim - codomain.event_dim
+
+        if reinterpreted > 0:
+            return constraints.independent(codomain, reinterpreted)
+        else:
+            return codomain
+
+    @property
+    def bijective(self) -> bool:
+        return all(t.bijective for t in self.transforms)
+
+    def _call(self, x: Tensor) -> Tensor:
+        for t in self.transforms:
+            x = t(x)
+        return x
+
+    def _inverse(self, y: Tensor) -> Tensor:
+        for t in reversed(self.transforms):
+            y = t.inv(y)
+        return y
+
+    def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
+        _, ladj = self.call_and_ladj(x)
+        return ladj
+
+    def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        event_dim = self.domain_dim
+        acc = 0
+
+        for t in self.transforms:
+            x, ladj = t.call_and_ladj(x)
+            acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
+            event_dim += t.codomain.event_dim - t.domain.event_dim
+
+        return x, acc
+
+    def forward_shape(self, shape: Size) -> Size:
+        for t in self.transforms:
+            shape = t.forward_shape(shape)
+        return shape
+
+    def inverse_shape(self, shape: Size) -> Size:
+        for t in reversed(self.transforms):
+            shape = t.inverse_shape(shape)
+        return shape
+
+
 class IdentityTransform(Transform):
     r"""Creates a transformation :math:`f(x) = x`."""
diff --git a/zuko/utils.py b/zuko/utils.py
index 843cce0..b9f4f45 100644
--- a/zuko/utils.py
+++ b/zuko/utils.py
@@ -6,7 +6,6 @@
 import numpy as np
 import torch
-import torch.nn as nn
 
 from functools import lru_cache
 from torch import Tensor
@@ -22,7 +21,7 @@ def bisection(
     phi: Iterable[Tensor] = (),
 ) -> Tensor:
     r"""Applies the bisection method to find :math:`x` between the bounds :math:`a`
-    an :math:`b` such that :math:`f_\phi(x)` is close to :math:`y`.
+    and :math:`b` such that :math:`f_\phi(x)` is close to :math:`y`.
 
     Gradients are propagated through :math:`y` and :math:`\phi` via implicit
     differentiation.
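
Usage sketch (not part of the diff above). It illustrates the API introduced by this change: ComposedTransform flattens a sequence of transformations into a single Transform, which NormalizingFlow now accepts directly, mirroring the composition that the flows in zuko.flows now perform internally before constructing the NormalizingFlow. The particular transformations and distribution parameters below are illustrative assumptions, not code from the repository.

    # Hypothetical example, assuming a zuko install that includes the changes above.
    from torch.distributions import AffineTransform, ExpTransform, Gamma

    from zuko.distributions import NormalizingFlow
    from zuko.transforms import ComposedTransform

    # f(x) = exp(2 * x): the affine map is applied first, then the exponential.
    f = ComposedTransform(AffineTransform(0.0, 2.0), ExpTransform())

    # X is defined by Z = f(X) with Z ~ Gamma(2, 1).
    d = NormalizingFlow(f, Gamma(2.0, 1.0))

    x = d.sample((5,))     # draws z ~ Gamma(2, 1), then returns f^{-1}(z)
    log_p = d.log_prob(x)  # log p(Z = f(x)) plus the log-abs-det Jacobian of f at x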