From 2d5a37f83d8b392e21a6c808a5471c078a68d3e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Tue, 20 Dec 2022 16:26:13 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20New=20features?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add call_and_ladj method to transformations to improve flow efficiency * Add continuous normalizing flow (CNF) * Add free-form Jacobian (FFJ) transformation * Add odeint solver * Make bisection auto-differentiable --- setup.py | 2 +- tests/test_flows.py | 20 +- tests/test_transforms.py | 38 +++- tests/test_utils.py | 41 +++- zuko/distributions.py | 74 +++++++- zuko/flows.py | 147 +++++++++++++-- zuko/nn.py | 112 ++++++----- zuko/transforms.py | 184 ++++++++++++++---- zuko/utils.py | 394 ++++++++++++++++++++++++++++++++------- 9 files changed, 818 insertions(+), 194 deletions(-) diff --git a/setup.py b/setup.py index 33cc483..bc89799 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setuptools.setup( name='zuko', - version='0.0.6', + version='0.0.7', packages=setuptools.find_packages(), description='Normalizing flows in PyTorch', keywords=[ diff --git a/tests/test_flows.py b/tests/test_flows.py index 1cf56e7..5dd84d7 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -14,6 +14,7 @@ def test_flows(tmp_path): SOSPF(3, 5), NAF(3, 5), NAF(3, 5, unconstrained=True), + CNF(3, 5), ] for flow in flows: @@ -24,16 +25,27 @@ def test_flows(tmp_path): assert log_p.shape == (256,), flow assert log_p.requires_grad, flow + flow.zero_grad(set_to_none=True) loss = -log_p.mean() loss.backward() for p in flow.parameters(): - assert hasattr(p, 'grad'), flow + assert p.grad is not None, flow # Sampling - z = flow(y).sample((32,)) + x = flow(y).sample((32,)) - assert z.shape == (32, 3), flow + assert x.shape == (32, 3), flow + + # Reparameterization trick + x = flow(y).rsample() + + flow.zero_grad(set_to_none=True) + loss = x.square().sum().sqrt() + loss.backward() + + for p in flow.parameters(): + assert p.grad is not None, flow # Invertibility x, y = randn(256, 3), randn(256, 5) @@ -58,7 +70,9 @@ def test_flows(tmp_path): x, y = randn(3), randn(5) + seed = torch.seed() log_p = flow(y).log_prob(x) + torch.manual_seed(seed) log_p_bis = flow_bis(y).log_prob(x) assert torch.allclose(log_p, log_p_bis), flow diff --git a/tests/test_transforms.py b/tests/test_transforms.py index c1d3e22..5d89af0 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -21,8 +21,9 @@ def test_univariate_transforms(): ] for t in ts: + # Call if hasattr(t.domain, 'lower_bound'): - x = torch.linspace(t.domain.lower_bound, t.domain.upper_bound, 256) + x = torch.linspace(t.domain.lower_bound + 1e-2, t.domain.upper_bound - 1e-2, 256) else: x = torch.linspace(-5.0, 5.0, 256) @@ -30,6 +31,7 @@ def test_univariate_transforms(): assert x.shape == y.shape, t + # Inverse z = t.inv(y) assert torch.allclose(x, z, atol=1e-4), t @@ -42,7 +44,39 @@ def test_univariate_transforms(): ladj = torch.diag(J).abs().log() - assert torch.allclose(ladj, t.log_abs_det_jacobian(x, y), atol=1e-4), t + assert torch.allclose(t.log_abs_det_jacobian(x, y), ladj, atol=1e-4), t + + # Inverse Jacobian + J = torch.autograd.functional.jacobian(t.inv, y) + + assert (torch.triu(J, diagonal=1) == 0).all(), t + assert (torch.tril(J, diagonal=-1) == 0).all(), t + + ladj = torch.diag(J).abs().log() + + assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t + + +def test_FFJTransform(): + a = torch.randn(3) + f = lambda x, t: a * x + t = 
FFJTransform(f, time=torch.tensor(1.0)) + + # Call + x = randn(256, 3) + y = t(x) + + assert x.shape == y.shape + + # Inverse + z = t.inv(y) + + assert torch.allclose(x, z, atol=1e-4) + + # Jacobian + ladj = t.log_abs_det_jacobian(x, y) + + assert ladj.shape == x.shape[:-1] def test_PermutationTransform(): diff --git a/tests/test_utils.py b/tests/test_utils.py index ac04128..78fd828 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,14 +10,15 @@ def test_bisection(): f = torch.cos + y = torch.tensor(0.0) a = torch.rand(256, 1) + 2.0 b = torch.rand(16) - x = bisection(f, a, b, n=18) + x = bisection(f, y, a, b, n=18) assert x.shape == (256, 16) assert torch.allclose(x, torch.tensor(math.pi / 2), atol=1e-4) - assert torch.allclose(f(x), torch.tensor(0.0), atol=1e-4) + assert torch.allclose(f(x), y, atol=1e-4) def test_broadcast(): @@ -59,17 +60,37 @@ def test_gauss_legendre(): # Polynomial f = lambda x: x**5 - x**2 F = lambda x: x**6 / 6 - x**3 / 3 - a, b = randn(2, 256) + a, b = randn(2, 256, requires_grad=True) area = gauss_legendre(f, a, b, n=3) - assert torch.allclose(F(b) - F(a), area, atol=1e-4) + assert torch.allclose(area, F(b) - F(a), atol=1e-4) # Gradients - grad_a, grad_b = torch.autograd.functional.jacobian( - lambda a, b: gauss_legendre(f, a, b).sum(), - (a, b), - ) + grad_a, grad_b = torch.autograd.grad(area.sum(), (a, b)) - assert torch.allclose(-f(a), grad_a) - assert torch.allclose(f(b), grad_b) + assert torch.allclose(grad_a, -f(a), atol=1e-4) + assert torch.allclose(grad_b, f(b), atol=1e-4) + + +def test_odeint(): + # Linear + alpha = torch.tensor(1.0, requires_grad=True) + t = torch.tensor(3.0, requires_grad=True) + + f = lambda x, t: -alpha * x + F = lambda x, t: x * (-alpha * t).exp() + + x0 = randn(256, 1, requires_grad=True) + xt = odeint(f, x0, torch.zeros_like(t), t, phi=(alpha,)) + + assert xt.shape == x0.shape + assert torch.allclose(xt, F(x0, t), atol=1e-4) + + # Gradients + grad_x0, grad_t, grad_alpha = torch.autograd.grad(xt.sum(), (x0, t, alpha)) + g_x0, g_t, g_alpha = torch.autograd.grad(F(x0, t).sum(), (x0, t, alpha)) + + assert torch.allclose(grad_x0, g_x0, atol=1e-4) + assert torch.allclose(grad_t, g_t, atol=1e-4) + assert torch.allclose(grad_alpha, g_alpha, atol=1e-4) diff --git a/zuko/distributions.py b/zuko/distributions.py index 4862fda..0a977c3 100644 --- a/zuko/distributions.py +++ b/zuko/distributions.py @@ -1,4 +1,4 @@ -r"""Parametrizable probability distributions.""" +r"""Parameterizable probability distributions.""" import math import torch @@ -7,6 +7,7 @@ from torch import Tensor, Size from torch.distributions import * from torch.distributions import constraints +from torch.distributions.utils import _sum_rightmost from typing import * @@ -14,7 +15,7 @@ Distribution.arg_constraints = {} -class NormalizingFlow(TransformedDistribution): +class NormalizingFlow(Distribution): r"""Creates a normalizing flow for a random variable :math:`X` towards a base distribution :math:`p(Z)` through a series of :math:`n` invertible and differentiable transformations :math:`f_1, f_2, \dots, f_n`. 
@@ -49,18 +50,77 @@ def __init__( transforms: List[Transform], base: Distribution, ): - super().__init__(base, [t.inv for t in reversed(transforms)]) + super().__init__() + + codomain_dim = ComposeTransform(transforms).codomain.event_dim + reinterpreted = codomain_dim - len(base.event_shape) + + if reinterpreted > 0: + base = Independent(base, reinterpreted) + + self.transforms = transforms + self.base = base def __repr__(self) -> str: - lines = [f'({i+1}): {t.inv}' for i, t in enumerate(reversed(self.transforms))] - lines.append(f'(base): {self.base_dist}') + lines = [f'({i + 1}): {t}' for i, t in enumerate(self.transforms)] + lines.append(f'(base): {self.base}') lines = indent('\n'.join(lines), ' ') return self.__class__.__name__ + '(\n' + lines + '\n)' - def expand(self, batch_shape: Size, new: Distribution = None) -> Distribution: + @property + def batch_shape(self) -> Size: + return self.base.batch_shape + + @property + def event_shape(self) -> Size: + shape = self.base.event_shape + + for t in reversed(self.transforms): + shape = t.inverse_shape(shape) + + return shape + + def expand(self, batch_shape: Size, new: Distribution = None): new = self._get_checked_instance(NormalizingFlow, new) - return super().expand(batch_shape, new) + new.transforms = self.transforms + new.base = self.base.expand(batch_shape) + + Distribution.__init__(new, batch_shape=batch_shape, validate_args=False) + + return new + + def log_prob(self, x: Tensor) -> Tensor: + acc = 0 + event_dim = len(self.event_shape) + + for t in self.transforms: + x, ladj = t.call_and_ladj(x) + acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim) + event_dim += t.codomain.event_dim - t.domain.event_dim + + return self.base.log_prob(x) + acc + + @property + def has_rsample(self) -> bool: + return self.base.has_rsample + + def rsample(self, shape: Size = ()): + x = self.base.rsample(shape) + + for t in reversed(self.transforms): + x = t.inv(x) + + return x + + def sample(self, shape: Size = ()): + with torch.no_grad(): + x = self.base.sample(shape) + + for t in reversed(self.transforms): + x = t.inv(x) + + return x class Joint(Distribution): diff --git a/zuko/flows.py b/zuko/flows.py index 3c9f085..bf1681d 100644 --- a/zuko/flows.py +++ b/zuko/flows.py @@ -11,6 +11,8 @@ 'NeuralAutoregressiveTransform', 'UnconstrainedNeuralAutoregressiveTransform', 'NAF', + 'FreeFormJacobianTransform', + 'CNF', ] import abc @@ -24,7 +26,7 @@ from .distributions import * from .transforms import * -from .nn import MLP, MaskedMLP, MonotonicMLP +from .nn import * from .utils import broadcast @@ -38,7 +40,7 @@ def forward(y: Tensor = None) -> Distribution: y: A context :math:`y`. Returns: - A distribution :math:`p(x | y)`. + A distribution :math:`p(X | y)`. """ pass @@ -84,7 +86,7 @@ def forward(self, y: Tensor = None) -> NormalizingFlow: y: A context :math:`y`. Returns: - A normalizing flow :math:`p(x | y)`. + A normalizing flow :math:`p(X | y)`. """ transforms = [t(y) for t in self.transforms] @@ -204,6 +206,8 @@ def __init__( if order is None: order = torch.arange(features) + else: + order = torch.as_tensor(order) self.passes = min(max(passes, 1), features) self.order = torch.div(order, ceil(features / self.passes), rounding_mode='floor') @@ -219,7 +223,7 @@ def extra_repr(self) -> str: base = self.univariate(*map(torch.randn, self.shapes)) order = self.order.tolist() - if len(order) > 11: + if len(order) > 10: order = str(order[:5] + [...] 
+ order[-5:]).replace('Ellipsis', '...') return '\n'.join([ @@ -231,15 +235,15 @@ def meta(self, y: Tensor, x: Tensor) -> Transform: if y is not None: x = torch.cat(broadcast(x, y, ignore=1), dim=-1) - params = self.hyper(x) - params = params.reshape(*params.shape[:-1], -1, sum(self.sizes)) - - args = params.split(self.sizes, dim=-1) - args = [a.reshape(a.shape[:-1] + s) for a, s in zip(args, self.shapes)] + phi = self.hyper(x) + phi = phi.unflatten(-1, (-1, sum(self.sizes))) + phi = phi.split(self.sizes, -1) + phi = (p.unflatten(-1, s + (1,)) for p, s in zip(phi, self.shapes)) + phi = (p.squeeze(-1) for p in phi) - return self.univariate(*args) + return self.univariate(*phi) - def forward(self, y: Tensor = None) -> AutoregressiveTransform: + def forward(self, y: Tensor = None) -> Transform: return AutoregressiveTransform(partial(self.meta, y), self.passes) @@ -484,7 +488,10 @@ def f(self, signal: Tensor, x: Tensor) -> Tensor: ).squeeze(dim=-1) def univariate(self, signal: Tensor) -> Transform: - return MonotonicTransform(partial(self.f, signal)) + return MonotonicTransform( + f=partial(self.f, signal), + phi=(signal, *self.network.parameters()), + ) class UnconstrainedNeuralAutoregressiveTransform(MaskedAutoregressiveTransform): @@ -564,7 +571,11 @@ def g(self, signal: Tensor, x: Tensor) -> Tensor: ).squeeze(dim=-1) def univariate(self, signal: Tensor, constant: Tensor) -> Transform: - return UnconstrainedMonotonicTransform(partial(self.g, signal), constant) + return UnconstrainedMonotonicTransform( + g=partial(self.g, signal), + C=constant, + phi=(signal, *self.integrand.parameters()), + ) class NAF(FlowModule): @@ -629,3 +640,113 @@ def __init__( ) super().__init__(transforms, base) + + +class FreeFormJacobianTransform(TransformModule): + r"""Creates a free-form Jacobian transformation. + + References: + | FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models (Grathwohl et al., 2018) + | https://arxiv.org/abs/1810.01367 + + Arguments: + features: The number of features. + context: The number of context features. + kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`. + + Example: + >>> t = FreeFormJacobianTranform(3, 4) + >>> t + FreeFormJacobianTranform( + (time): 1.000 + (ode): MLP( + (0): Linear(in_features=8, out_features=64, bias=True) + (1): ELU(alpha=1.0) + (2): Linear(in_features=64, out_features=64, bias=True) + (3): ELU(alpha=1.0) + (4): Linear(in_features=64, out_features=3, bias=True) + ) + ) + >>> x = torch.randn(3) + >>> x + tensor([ 0.1777, 1.0139, -1.0370]) + >>> y = torch.randn(4) + >>> z = t(y)(x) + >>> t(y).inv(z) + tensor([ 0.1777, 1.0139, -1.0370]) + """ + + def __init__( + self, + features: int, + context: int = 0, + **kwargs, + ): + super().__init__() + + kwargs.setdefault('activation', nn.ELU) + + self.ode = MLP(features + 1 + context, features, **kwargs) + self.log_t = nn.Parameter(torch.tensor(0.0)) + + def extra_repr(self) -> str: + return f'(time): {self.log_t.exp().item():.3f}' + + def f(self, y: Tensor, x: Tensor, t: Tensor) -> Tensor: + if y is None: + x = torch.cat(broadcast(x, t[..., None], ignore=1), dim=-1) + else: + x = torch.cat(broadcast(x, t[..., None], y, ignore=1), dim=-1) + + return self.ode(x) + + def forward(self, y: Tensor = None) -> Transform: + return FFJTransform( + f=partial(self.f, y), + time=self.log_t.exp(), + phi=(y, *self.ode.parameters()), + ) + + +class CNF(FlowModule): + r"""Creates a continuous normalizing flow (CNF) with free-form Jacobian + transformations. 
+ + References: + | Neural Ordinary Differential Equations (Chen el al., 2018) + | https://arxiv.org/abs/1806.07366 + + | FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models (Grathwohl et al., 2018) + | https://arxiv.org/abs/1810.01367 + + Arguments: + features: The number of features. + context: The number of context features. + transforms: The number of transformations. + kwargs: Keyword arguments passed to :class:`FreeFormJacobianTransform`. + """ + + def __init__( + self, + features: int, + context: int = 0, + transforms: int = 1, + **kwargs, + ): + transforms = [ + FreeFormJacobianTransform( + features=features, + context=context, + **kwargs, + ) + for _ in range(transforms) + ] + + base = Unconditional( + DiagNormal, + torch.zeros(features), + torch.ones(features), + buffer=True, + ) + + super().__init__(transforms, base) diff --git a/zuko/nn.py b/zuko/nn.py index b2d2a95..1c3667e 100644 --- a/zuko/nn.py +++ b/zuko/nn.py @@ -10,26 +10,30 @@ from typing import * -class BatchNorm0d(nn.BatchNorm1d): - r"""Creates a batch normalization (BatchNorm) layer for scalars. +class LayerNorm(nn.Module): + r"""Creates a normalization layer that standardizes features along a dimension. + + .. math:: y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} References: - | Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (Ioffe et al., 2015) - | https://arxiv.org/abs/1502.03167 + Layer Normalization (Lei Ba et al., 2016) + https://arxiv.org/abs/1607.06450 Arguments: - args: Positional arguments passed to :class:`torch.nn.BatchNorm1d`. - kwargs: Keyword arguments passed to :class:`torch.nn.BatchNorm1d`. + dim: The dimension(s) to standardize. + eps: A numerical stability term. """ - def forward(self, x: Tensor) -> Tensor: - shape = x.shape + def __init__(self, dim: Union[int, Iterable[int]] = -1, eps: float = 1e-5): + super().__init__() - x = x.reshape(-1, shape[-1]) - x = super().forward(x) - x = x.reshape(shape) + self.dim = dim if type(dim) is int else tuple(dim) + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + variance, mean = torch.var_mean(x, unbiased=True, dim=self.dim, keepdim=True) - return x + return (x - mean) / (variance + self.eps).sqrt() class MLP(nn.Sequential): @@ -54,8 +58,7 @@ class MLP(nn.Sequential): hidden_features: The numbers of hidden features. activation: The activation function constructor. If :py:`None`, use :class:`torch.nn.ReLU` instead. - batchnorm: Whether to use batch normalization or not. - dropout: The dropout rate. + normalize: Whether features are normalized between layers or not. kwargs: Keyword arguments passed to :class:`torch.nn.Linear`. 
Example: @@ -76,15 +79,13 @@ def __init__( out_features: int, hidden_features: List[int] = [64, 64], activation: Callable[[], nn.Module] = None, - batchnorm: bool = False, - dropout: float = 0.0, + normalize: bool = False, **kwargs, ): if activation is None: activation = nn.ReLU - batchnorm = BatchNorm0d if batchnorm else lambda _: None - dropout = nn.Dropout(dropout) if dropout > 0 else None + normalization = LayerNorm if normalize else lambda: None layers = [] @@ -94,12 +95,11 @@ def __init__( ): layers.extend([ nn.Linear(before, after, **kwargs), - batchnorm(after), activation(), - dropout, + normalization(), ]) - layers = layers[:-3] + layers = layers[:-2] layers = filter(lambda l: l is not None, layers) super().__init__(*layers) @@ -123,16 +123,11 @@ def __init__(self, adjacency: BoolTensor, **kwargs): self.register_buffer('mask', adjacency) - degree = adjacency.sum(dim=-1) - rescale = adjacency.shape[-1] / torch.clip(degree, min=1) - - self.weight.data *= rescale[:, None] - def forward(self, x: Tensor) -> Tensor: return F.linear(x, self.mask * self.weight, self.bias) -class MaskedMLP(MLP): +class MaskedMLP(nn.Sequential): r"""Creates a masked multi-layer perceptron (MaskedMLP). The resulting MLP is a transformation :math:`y = f(x)` whose Jacobian entries @@ -140,8 +135,9 @@ class MaskedMLP(MLP): Arguments: adjacency: The adjacency matrix :math:`A \in \{0, 1\}^{M \times N}`. - args: Positional arguments passed to :class:`MLP`. - kwargs: Keyword arguments passed to :class:`MLP`. + hidden_features: The numbers of hidden features. + activation: The activation function constructor. If :py:`None`, use + :class:`torch.nn.ReLU` instead. Example: >>> adjacency = torch.randn(4, 3) < 0 @@ -167,33 +163,53 @@ class MaskedMLP(MLP): [ 0.0000, 0.0060, -0.0063]]) """ - def __init__(self, adjacency: BoolTensor, *args, **kwargs): - super().__init__(*reversed(adjacency.shape), *args, **kwargs) + def __init__( + self, + adjacency: BoolTensor, + hidden_features: List[int] = [64, 64], + activation: Callable[[], nn.Module] = None, + ): + out_features, in_features = adjacency.shape + + if activation is None: + activation = nn.ReLU # Merge outputs with the same dependencies adjacency, inverse = torch.unique(adjacency, dim=0, return_inverse=True) - # j precedes i if A_ik = 1 for all k such that A_jk = 1 + # P_ij = 1 if A_ik = 1 for all k such that A_jk = 1 precedence = adjacency.int() @ adjacency.int().t() == adjacency.sum(dim=-1) - for i, layer in enumerate(self): - if isinstance(layer, nn.Linear): - if i > 0: - mask = precedence[:, indices] - else: - mask = adjacency + # Layers + layers = [] + + for i, features in enumerate(hidden_features + [out_features]): + if i > 0: + mask = precedence[:, indices] + else: + mask = adjacency + + if (~mask).all(): + raise ValueError("The adjacency matrix leads to a null Jacobian.") - if (~mask).all(): - raise ValueError("The adjacency matrix leads to a null Jacobian.") + if i < len(hidden_features): + reachable = mask.sum(dim=-1).nonzero().squeeze(dim=-1) + indices = reachable[torch.arange(features) % len(reachable)] + mask = mask[indices] + else: + mask = mask[inverse] - if i < len(self) - 1: - reachable = mask.sum(dim=-1).nonzero().squeeze(dim=-1) - indices = reachable[torch.arange(layer.out_features) % len(reachable)] - mask = mask[indices] - else: - mask = mask[inverse] + layers.extend([ + MaskedLinear(adjacency=mask), + activation(), + ]) + + layers = layers[:-1] - self[i] = MaskedLinear(adjacency=mask) + super().__init__(*layers) + + self.in_features = in_features + 
self.out_features = out_features class MonotonicLinear(nn.Linear): @@ -258,7 +274,7 @@ class MonotonicMLP(MLP): def __init__(self, *args, **kwargs): kwargs['activation'] = nn.ELU - kwargs['batchnorm'] = False + kwargs['normalize'] = False super().__init__(*args, **kwargs) diff --git a/zuko/transforms.py b/zuko/transforms.py index f45f7ca..730299b 100644 --- a/zuko/transforms.py +++ b/zuko/transforms.py @@ -9,7 +9,23 @@ from torch.distributions import constraints from typing import * -from .utils import bisection, broadcast, gauss_legendre +from .utils import bisection, broadcast, gauss_legendre, odeint + + +torch.distributions.transforms._InverseTransform.__name__ = 'Inverse' + + +def _call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: + r"""Returns both the transformed value and the log absolute determinant of the + transformation's Jacobian.""" + + y = self._call(x) + ladj = self.log_abs_det_jacobian(x, y) + + return y, ladj + + +Transform.call_and_ladj = _call_and_ladj class IdentityTransform(Transform): @@ -249,27 +265,36 @@ def _inverse(self, y: Tensor) -> Tensor: return torch.where(mask, x, y) def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: + _, ladj = self.call_and_ladj(x) + return ladj + + def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: k = self.searchsorted(self.horizontal, x) - 1 mask, x0, x1, y0, y1, d0, d1, s = self.bin(k) z = mask * (x - x0) / (x1 - x0) + y = y0 + (y1 - y0) * (s * z**2 + d0 * z * (1 - z)) / ( + s + (d0 + d1 - 2 * s) * z * (1 - z) + ) + jacobian = ( s**2 * (2 * s * z * (1 - z) + d0 * (1 - z) ** 2 + d1 * z**2) / (s + (d0 + d1 - 2 * s) * z * (1 - z)) ** 2 ) - return mask * jacobian.log() + return torch.where(mask, y, x), mask * jacobian.log() class MonotonicTransform(Transform): - r"""Creates a transformation from a monotonic univariate function :math:`f(x)`. + r"""Creates a transformation from a monotonic univariate function :math:`f_\phi(x)`. - The inverse function :math:`f^{-1}` is approximated using the bisection method. + The inverse function :math:`f_\phi^{-1}` is approximated using the bisection method. Arguments: - f: A monotonic univariate function :math:`f(x)`. + f: A monotonic univariate function :math:`f_\phi`. + phi: The parameters :math:`\phi` of :math:`f_\phi`. bound: The domain bound :math:`B`. eps: The absolute tolerance for the inverse transformation. 
""" @@ -282,6 +307,7 @@ class MonotonicTransform(Transform): def __init__( self, f: Callable[[Tensor], Tensor], + phi: Iterable[Tensor] = (), bound: float = 5.0, eps: float = 1e-6, **kwargs, @@ -289,6 +315,7 @@ def __init__( super().__init__(**kwargs) self.f = f + self.phi = tuple(filter(lambda p: p.requires_grad, phi)) self.bound = bound self.eps = eps @@ -297,20 +324,26 @@ def _call(self, x: Tensor) -> Tensor: def _inverse(self, y: Tensor) -> Tensor: return bisection( - f=lambda x: self.f(x) - y, + f=self.f, + y=y, a=torch.full_like(y, -self.bound), b=torch.full_like(y, self.bound), n=math.ceil(math.log2(2 * self.bound / self.eps)), + phi=self.phi, ) def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: - return torch.log( - torch.autograd.functional.jacobian( - func=lambda x: self.f(x).sum(), - inputs=x, - create_graph=True, - ) - ) + _, ladj = self.call_and_ladj(x) + return ladj + + def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: + with torch.enable_grad(): + x = x.requires_grad_() + y = self.f(x) + + jacobian = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0] + + return y, jacobian.log() class UnconstrainedMonotonicTransform(MonotonicTransform): @@ -322,7 +355,7 @@ class UnconstrainedMonotonicTransform(MonotonicTransform): The definite integral is estimated by a :math:`n`-point Gauss-Legendre quadrature. Arguments: - g: A positive univariate function :math:`g(x)`. + g: A positive univariate function :math:`g`. C: The integration constant :math:`C`. n: The number of points :math:`n` for the quadrature. kwargs: Keyword arguments passed to :class:`MonotonicTransform`. @@ -352,11 +385,15 @@ def f(self, x: Tensor) -> Tensor: a=torch.zeros_like(x), b=x, n=self.n, + phi=self.phi, ) + self.C def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: return self.g(x).log() + def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: + return self.f(x), self.g(x).log() + class SOSPolynomialTransform(UnconstrainedMonotonicTransform): r"""Creates a sum-of-squares (SOS) polynomial transformation. @@ -383,7 +420,7 @@ class SOSPolynomialTransform(UnconstrainedMonotonicTransform): sign = +1 def __init__(self, a: Tensor, C: Tensor, **kwargs): - super().__init__(self.g, C, a.shape[-1], **kwargs) + super().__init__(self.g, C, phi=(a,), n=a.shape[-1], **kwargs) self.a = a self.i = torch.arange(a.shape[-1]).to(a.device) @@ -396,8 +433,87 @@ def g(self, x: Tensor) -> Tensor: return p.squeeze(dim=-1).square().sum(dim=-1) +class FFJTransform(Transform): + r"""Creates a free-form Jacobian (FFJ) transformation. + + The transformation is the integration of a system of first-order ordinary + differential equations + + .. math:: x(T) = \int_0^T f_\phi(x(t), t) ~ dt . + + The log-determinant of the Jacobian is replaced by an unbiased stochastic + linear-time estimate. + + References: + | FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models (Grathwohl et al., 2018) + | https://arxiv.org/abs/1810.01367 + + Arguments: + f: A system of first-order ODEs :math:`f_\phi`. + time: The integration time :math:`T`. + phi: The parameters :math:`\phi` of :math:`f_\phi`. 
+ """ + + domain = constraints.real_vector + codomain = constraints.real_vector + bijective = True + + def __init__( + self, + f: Callable[[Tensor, Tensor], Tensor], + time: Tensor, + phi: Iterable[Tensor] = (), + **kwargs, + ): + super().__init__(**kwargs) + + self.f = f + self.t0 = time.new_tensor(0.0) + self.t1 = time + self.phi = tuple(filter(lambda p: p.requires_grad, phi)) + + def _call(self, x: Tensor) -> Tensor: + return odeint(self.f, x, self.t0, self.t1, self.phi) + + def _inverse(self, y: Tensor) -> Tensor: + return odeint(self.f, y, self.t1, self.t0, self.phi) + + def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: + _, ladj = self.call_and_ladj(x) + return ladj + + def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: + shape = x.shape + size = x.numel() + + eps = torch.randn_like(x) + + def f_aug(x_aug: Tensor, t: Tensor) -> Tensor: + x = x_aug[:size].reshape(shape) + + with torch.enable_grad(): + x = x.requires_grad_() + dx = self.f(x, t) + + epsjp = torch.autograd.grad(dx, x, eps, create_graph=True)[0] + trace = (epsjp * eps).sum(dim=-1) + + return torch.cat((dx.flatten(), trace.flatten())) + + zeros = x.new_zeros(shape[:-1]) + + x_aug = torch.cat((x.flatten(), zeros.flatten())) + y_aug = odeint(f_aug, x_aug, self.t0, self.t1, self.phi) + + y, score = y_aug[:size], y_aug[size:] + + return y.reshape(shape), score.reshape(shape[:-1]) + + class AutoregressiveTransform(Transform): - r"""Transform via an autoregressive mapping. + r"""Transform via an autoregressive scheme. + + .. math:: y_i = f(x_i; x_{ Tensor: - _x, _f = self._cache - - if x is _x: - f = _f - else: - f = self.meta(x) - - self._cache = x, f - - return f(x) + return self.meta(x)(x) def _inverse(self, y: Tensor) -> Tensor: x = torch.zeros_like(y) @@ -441,16 +546,11 @@ def _inverse(self, y: Tensor) -> Tensor: return x def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: - _x, _f = self._cache + return self.meta(x).log_abs_det_jacobian(x, y).sum(dim=-1) - if x is _x: - f = _f - else: - f = self.meta(x) - - self._cache = x, f - - return f.log_abs_det_jacobian(x, y).sum(dim=-1) + def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]: + y, ladj = self.meta(x).call_and_ladj(x) + return y, ladj.sum(dim=-1) class PermutationTransform(Transform): @@ -468,16 +568,20 @@ def __init__(self, order: LongTensor, **kwargs): super().__init__(**kwargs) self.order = order - self.inverse = torch.argsort(order) def __repr__(self) -> str: - return f'{self.__class__.__name__}({self.order.tolist()})' + order = self.order.tolist() + + if len(order) > 10: + order = str(order[:5] + [...] 
+ order[-5:]).replace('Ellipsis', '...') + + return f'{self.__class__.__name__}({order})' def _call(self, x: Tensor) -> Tensor: return x[..., self.order] def _inverse(self, y: Tensor) -> Tensor: - return y[..., self.inverse] + return y[..., torch.argsort(self.order)] def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor: return x.new_zeros(x.shape[:-1]) diff --git a/zuko/utils.py b/zuko/utils.py index 3d473b6..843cce0 100644 --- a/zuko/utils.py +++ b/zuko/utils.py @@ -1,6 +1,8 @@ r"""General purpose helpers.""" -__all__ = ['bisection', 'broadcast', 'gauss_legendre'] +from __future__ import annotations + +__all__ = ['bisection', 'broadcast', 'gauss_legendre', 'odeint'] import numpy as np import torch @@ -13,40 +15,86 @@ def bisection( f: Callable[[Tensor], Tensor], - a: Tensor, - b: Tensor, + y: Tensor, + a: Union[float, Tensor], + b: Union[float, Tensor], n: int = 16, + phi: Iterable[Tensor] = (), ) -> Tensor: - r"""Applies the bisection method to find a root :math:`x` of a function - :math:`f(x)` between the bounds :math:`a` an :math:`b`. + r"""Applies the bisection method to find :math:`x` between the bounds :math:`a` + an :math:`b` such that :math:`f_\phi(x)` is close to :math:`y`. + + Gradients are propagated through :math:`y` and :math:`\phi` via implicit + differentiation. Wikipedia: https://wikipedia.org/wiki/Bisection_method Arguments: - f: A univariate function :math:`f(x)`. - a: The bound :math:`a` such that :math:`f(a) \leq 0`. - b: The bound :math:`b` such that :math:`0 \leq f(b)`. + f: A univariate function :math:`f_\phi`. + y: The target :math:`y`. + a: The bound :math:`a` such that :math:`f_\phi(a) \leq y`. + b: The bound :math:`b` such that :math:`y \leq f_\phi(b)`. n: The number of iterations. + phi: The parameters :math:`\phi` of :math:`f_\phi`. 
Example: >>> f = torch.cos - >>> a = torch.tensor(2.0) - >>> b = torch.tensor(1.0) - >>> bisection(f, a, b, n=16) + >>> y = torch.tensor(0.0) + >>> bisection(f, y, 2.0, 1.0, n=16) tensor(1.5708) """ - with torch.no_grad(): + a = torch.as_tensor(a).to(y) + b = torch.as_tensor(b).to(y) + + return Bisection.apply(f, y, a, b, n, *phi) + + +class Bisection(torch.autograd.Function): + @staticmethod + def forward( + ctx, + f: Callable[[Tensor], Tensor], + y: Tensor, + a: Tensor, + b: Tensor, + n: int, + *phi: Tensor, + ) -> Tensor: + ctx.f = f + ctx.save_for_backward(*phi) + for _ in range(n): c = (a + b) / 2 - mask = f(c) < 0 + mask = f(c) < y a = torch.where(mask, c, a) b = torch.where(mask, b, c) - return (a + b) / 2 + ctx.x = (a + b) / 2 + + return ctx.x + + @staticmethod + def backward(ctx, grad_x: Tensor) -> Tuple[Tensor, ...]: + f, x = ctx.f, ctx.x + phi = ctx.saved_tensors + + with torch.enable_grad(): + x = x.detach().requires_grad_() + y = f(x) + + jacobian = torch.autograd.grad(y, x, torch.ones_like(y), retain_graph=True)[0] + grad_y = grad_x / jacobian + + if phi: + grad_phi = torch.autograd.grad(y, phi, -grad_y, retain_graph=True) + else: + grad_phi = () + + return (None, grad_y, None, None, None, *grad_phi) def broadcast(*tensors: Tensor, ignore: Union[int, List[int]] = 0) -> List[Tensor]: @@ -78,26 +126,57 @@ def broadcast(*tensors: Tensor, ignore: Union[int, List[int]] = 0) -> List[Tenso return [torch.broadcast_to(t, common + t.shape[i:]) for t, i in zip(tensors, dims)] -class AttachLimits(torch.autograd.Function): - r"""Attaches the limits of integration to the computational graph.""" +def gauss_legendre( + f: Callable[[Tensor], Tensor], + a: Tensor, + b: Tensor, + n: int = 3, + phi: Iterable[Tensor] = (), +) -> Tensor: + r"""Estimates the definite integral of a function :math:`f_\phi(x)` from :math:`a` + to :math:`b` using a :math:`n`-point Gauss-Legendre quadrature. + + .. math:: \int_a^b f_\phi(x) ~ dx \approx (b - a) \sum_{i = 1}^n w_i f_\phi(x_i) + + Wikipedia: + https://wikipedia.org/wiki/Gauss-Legendre_quadrature + + Arguments: + f: A univariate function :math:`f_\phi`. + a: The lower limit :math:`a`. + b: The upper limit :math:`b`. + n: The number of points :math:`n` at which the function is evaluated. + phi: The parameters :math:`\phi` of :math:`f_\phi`. 
+ + Example: + >>> f = lambda x: torch.exp(-x**2) + >>> a, b = torch.tensor([-0.69, 4.2]) + >>> gauss_legendre(f, a, b, n=16) + tensor(1.4807) + """ + + return GaussLegendre.apply(f, a, b, n, *phi) + +class GaussLegendre(torch.autograd.Function): @staticmethod def forward( ctx, f: Callable[[Tensor], Tensor], a: Tensor, b: Tensor, - area: Tensor, + n: int, + *phi: Tensor, ) -> Tensor: - ctx.f = f - ctx.save_for_backward(a, b) + ctx.f, ctx.n = f, n + ctx.save_for_backward(a, b, *phi) - return area + return GaussLegendre.quadrature(f, a, b, n) @staticmethod def backward(ctx, grad_area: Tensor) -> Tuple[Tensor, ...]: - f = ctx.f - a, b = ctx.saved_tensors + f, n = ctx.f, ctx.n + a, b, *phi = ctx.saved_tensors if ctx.needs_input_grad[1]: grad_a = -f(a) * grad_area @@ -109,74 +188,249 @@ def backward(ctx, grad_area: Tensor) -> Tuple[Tensor, ...]: else: grad_b = None - return None, grad_a, grad_b, grad_area + if phi: + with torch.enable_grad(): + area = GaussLegendre.quadrature(f, a.detach(), b.detach(), n) + grad_phi = torch.autograd.grad(area, phi, grad_area, retain_graph=True) + else: + grad_phi = () -def gauss_legendre( - f: Callable[[Tensor], Tensor], - a: Tensor, - b: Tensor, - n: int = 3, + return (None, grad_a, grad_b, None, *grad_phi) + + @staticmethod + @lru_cache(maxsize=None) + def nodes(n: int, **kwargs) -> Tuple[Tensor, Tensor]: + r"""Returns the nodes and weights for a :math:`n`-point Gauss-Legendre + quadrature over the interval :math:`[0, 1]`. + + See :func:`numpy.polynomial.legendre.leggauss`. + """ + + nodes, weights = np.polynomial.legendre.leggauss(n) + + nodes = (nodes + 1) / 2 + weights = weights / 2 + + kwargs.setdefault('dtype', torch.get_default_dtype()) + + return ( + torch.as_tensor(nodes, **kwargs), + torch.as_tensor(weights, **kwargs), + ) + + @staticmethod + def quadrature( + f: Callable[[Tensor], Tensor], + a: Tensor, + b: Tensor, + n: int, + ) -> Tensor: + nodes, weights = GaussLegendre.nodes(n, dtype=a.dtype, device=a.device) + nodes = torch.lerp( + a[..., None], + b[..., None], + nodes, + ).movedim(-1, 0) + + return (b - a) * torch.tensordot(weights, f(nodes), dims=1) + + +def odeint( + f: Callable[[Tensor, Tensor], Tensor], + x: Tensor, + t0: Union[float, Tensor], + t1: Union[float, Tensor], + phi: Iterable[Tensor] = (), ) -> Tensor: - r"""Estimates the definite integral of :math:`f` from :math:`a` to :math:`b` - using a :math:`n`-point Gauss-Legendre quadrature. + r"""Integrates a system of first-order ordinary differential equations (ODEs) - .. math:: \int_a^b f(x) ~ dx \approx (b - a) \sum_{i = 1}^n w_i f(x_i) + .. math:: \frac{\mathrm{d} x}{\mathrm{d} t} = f_\phi(x, t) , - Wikipedia: - https://wikipedia.org/wiki/Gauss-Legendre_quadrature + from :math:`t_0` to :math:`t_1` using the adaptive Dormand-Prince method. The + output is the final state + + .. math:: x(t_1) = x_0 + \int_{t_0}^{t_1} f_\phi(x(t), t) ~ dt . + + Gradients are propagated through :math:`x_0`, :math:`t_0`, :math:`t_1` and + :math:`\phi` via the adaptive checkpoint adjoint (ACA) method. + + References: + | Neural Ordinary Differential Equations (Chen el al., 2018) + | https://arxiv.org/abs/1806.07366 + + | Adaptive Checkpoint Adjoint Method for Gradient Estimation in Neural ODE (Zhuang et al., 2020) + | https://arxiv.org/abs/2006.02493 Arguments: - f: A univariate function :math:`f(x)`. - a: The lower limit :math:`a`. - b: The upper limit :math:`b`. - n: The number of points :math:`n` at which the function is evaluated. + f: A system of first-order ODEs :math:`f_\phi`. 
+ x: The initial state :math:`x_0`. + t0: The initial integration time :math:`t_0`. + t1: The final integration time :math:`t_1`. + phi: The parameters :math:`\phi` of :math:`f_\phi`. Example: - >>> f = lambda x: torch.exp(-x**2) - >>> a, b = torch.tensor([-0.69, 4.2]) - >>> gauss_legendre(f, a, b, n=16) - tensor(1.4807) + >>> A = torch.randn(3, 3) + >>> f = lambda x, t: x @ A + >>> x0 = torch.randn(3) + >>> x1 = odeint(f, x0, 0.0, 1.0) + >>> x1 + tensor([-3.7454, -0.4140, 0.2677]) """ - nodes, weights = leggauss(n, dtype=a.dtype, device=a.device) - nodes = torch.lerp( - a[..., None].detach(), - b[..., None].detach(), - nodes, - ).movedim(-1, 0) + t0 = torch.as_tensor(t0).to(x) + t1 = torch.as_tensor(t1).to(x) - area = (b - a).detach() * torch.tensordot(weights, f(nodes), dims=1) + return AdaptiveCheckpointAdjoint.apply(f, x, t0, t1, *phi) - return AttachLimits.apply(f, a, b, area) +def dopri45( + f: Callable[[Tensor, Tensor], Tensor], + x: Tensor, + t: Tensor, + dt: Tensor, + error: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + r"""Applies one step of the Dormand-Prince method. -@lru_cache(maxsize=None) -def leggauss(n: int, **kwargs) -> Tuple[Tensor, Tensor]: - r"""Returns the nodes and weights for a :math:`n`-point Gauss-Legendre - quadrature over the interval :math:`[0, 1]`. + Wikipedia: + https://wikipedia.org/wiki/Dormand-Prince_method + """ - See :func:`numpy.polynomial.legendre.leggauss`. + k1 = dt * f(x, t) + k2 = dt * f(x + 1 / 5 * k1, t + 1 / 5 * dt) + k3 = dt * f(x + 3 / 40 * k1 + 9 / 40 * k2, t + 3 / 10 * dt) + k4 = dt * f(x + 44 / 45 * k1 - 56 / 15 * k2 + 32 / 9 * k3, t + 4 / 5 * dt) + k5 = dt * f( + x + 19372 / 6561 * k1 - 25360 / 2187 * k2 + 64448 / 6561 * k3 - 212 / 729 * k4, + t + 8 / 9 * dt, + ) + k6 = dt * f( + x + + 9017 / 3168 * k1 + - 355 / 33 * k2 + + 46732 / 5247 * k3 + + 49 / 176 * k4 + - 5103 / 18656 * k5, + t + dt, + ) + x_next = ( + x + + 35 / 384 * k1 + + 500 / 1113 * k3 + + 125 / 192 * k4 + - 2187 / 6784 * k5 + + 11 / 84 * k6 + ) - Arguments: - n: The number of points :math:`n`. + if not error: + return x_next + + k7 = dt * f(x_next, t + dt) + x_star = ( + x + + 5179 / 57600 * k1 + + 7571 / 16695 * k3 + + 393 / 640 * k4 + - 92097 / 339200 * k5 + + 187 / 2100 * k6 + + 1 / 40 * k7 + ) - Example: - >>> nodes, weights = leggauss(3) - >>> nodes - tensor([0.1127, 0.5000, 0.8873]) - >>> weights - tensor([0.2778, 0.4444, 0.2778]) + return x_next, abs(x_next - x_star) + + +class NestedTensor(tuple): + r"""Creates an efficient data-structure to hold and perform basic operations on + lists of tensors. 
""" - nodes, weights = np.polynomial.legendre.leggauss(n) + def __new__(cls, tensors: Iterable[Tensor] = ()) -> NestedTensor: + return tuple.__new__(cls, tensors) - nodes = (nodes + 1) / 2 - weights = weights / 2 + def __add__(self, other: NestedTensor) -> NestedTensor: + return NestedTensor(x + y for x, y in zip(self, other)) - kwargs.setdefault('dtype', torch.get_default_dtype()) + def __sub__(self, other: NestedTensor) -> NestedTensor: + return NestedTensor(x - y for x, y in zip(self, other)) - return ( - torch.as_tensor(nodes, **kwargs), - torch.as_tensor(weights, **kwargs), - ) + def __rmul__(self, factor: Tensor) -> NestedTensor: + return NestedTensor(factor * x for x in self) + + def __abs__(self) -> NestedTensor: + return NestedTensor(map(abs, self)) + + +class AdaptiveCheckpointAdjoint(torch.autograd.Function): + @staticmethod + def forward( + ctx, + f: Callable[[Tensor, Tensor], Tensor], + x: Tensor, + t0: Tensor, + t1: Tensor, + *phi: Tensor, + ) -> Tensor: + ctx.f = f + ctx.save_for_backward(x, t0, t1, *phi) + ctx.steps = [] + + t, dt = t0, t1 - t0 + sign = torch.sign(dt) + + while sign * (t1 - t) > 0: + dt = sign * torch.min(abs(dt), abs(t1 - t)) + + while True: + y, error = dopri45(f, x, t, dt, error=True) + tolerance = 1e-6 + 1e-5 * torch.max(abs(x), abs(y)) + error = torch.max(error / tolerance).item() + 1e-6 + + if error < 1.0: + x, t = y, t + dt + ctx.steps.append((x, t, dt)) + + dt = dt * min(10.0, max(0.1, 0.9 / error ** (1 / 5))) + + if error < 1.0: + break + + return x + + @staticmethod + def backward(ctx, grad_x: Tensor) -> Tuple[Tensor, ...]: + f = ctx.f + x0, t0, t1, *phi = ctx.saved_tensors + x1, _, _ = ctx.steps[-1] + + # Final time + if ctx.needs_input_grad[3]: + grad_t1 = f(x1, t1) * grad_x + else: + grad_t1 = None + + # Adjoint + grad_phi = tuple(map(torch.zeros_like, phi)) + + def g(x: NestedTensor, t: Tensor) -> NestedTensor: + x, grad_x, *_ = x + + with torch.enable_grad(): + x = x.detach().requires_grad_() + dx = f(x, t) + + grad_x, *grad_phi = torch.autograd.grad(dx, (x, *phi), -grad_x, retain_graph=True) + + return NestedTensor((dx, grad_x, *grad_phi)) + + for x, t, dt in reversed(ctx.steps): + x = NestedTensor((x, grad_x, *grad_phi)) + x, grad_x, *grad_phi = dopri45(g, x, t, -dt) + + # Initial time + if ctx.needs_input_grad[2]: + grad_t0 = f(x0, t0) * grad_x + else: + grad_t0 = None + + return (None, grad_x, grad_t0, grad_t1, *grad_phi)