♻️ Refactor NormalizingFlow class
francois-rozet committed Dec 31, 2022
1 parent a755a69 commit 3e9c68d
Showing 8 changed files with 162 additions and 57 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@

setuptools.setup(
name='zuko',
version='0.0.7',
version='0.0.8',
packages=setuptools.find_packages(),
description='Normalizing flows in PyTorch',
keywords=[
3 changes: 2 additions & 1 deletion tests/test_distributions.py
@@ -3,12 +3,13 @@
import pytest
import torch

from torch.distributions import *
from zuko.distributions import *


def test_distributions():
ds = [
NormalizingFlow([ExpTransform()], Gamma(2.0, 1.0)),
NormalizingFlow(ExpTransform(), Gamma(2.0, 1.0)),
Joint(Uniform(0.0, 1.0), Normal(0.0, 1.0)),
GeneralizedNormal(2.0),
DiagNormal(torch.zeros(2), torch.ones(2)),
1 change: 1 addition & 0 deletions tests/test_transforms.py
@@ -4,6 +4,7 @@
import torch

from torch import randn
from torch.distributions import *
from zuko.transforms import *


92 changes: 43 additions & 49 deletions zuko/distributions.py
@@ -1,5 +1,19 @@
r"""Parameterizable probability distributions."""

__all__ = [
'NormalizingFlow',
'Joint',
'GeneralizedNormal',
'DiagNormal',
'BoxUniform',
'TransformedUniform',
'Truncated',
'Sort',
'TopK',
'Minimum',
'Maximum',
]

import math
import torch

@@ -17,14 +31,12 @@

class NormalizingFlow(Distribution):
r"""Creates a normalizing flow for a random variable :math:`X` towards a base
distribution :math:`p(Z)` through a series of :math:`n` invertible and differentiable
transformations :math:`f_1, f_2, \dots, f_n`.
distribution :math:`p(Z)` through a transformation :math:`f`.
The density of a realization :math:`x` is given by the change of variables
.. math:: p(X = x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right|
.. math:: p(X = x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right| .
where :math:`f = f_1 \circ \dots \circ f_n` is the transformations' composition.
To sample from :math:`p(X)`, realizations :math:`z \sim p(Z)` are mapped through
the inverse transformation :math:`g = f^{-1}`.
@@ -36,34 +48,38 @@ class NormalizingFlow(Distribution):
| https://arxiv.org/abs/1912.02762
Arguments:
transforms: A list of transformations :math:`f_i`.
transforms: A transformation :math:`f`.
base: A base distribution :math:`p(Z)`.
Example:
>>> d = NormalizingFlow([ExpTransform()], Gamma(2.0, 1.0))
>>> d = NormalizingFlow(ExpTransform(), Gamma(2.0, 1.0))
>>> d.sample()
tensor(1.1316)
"""

has_rsample = True

def __init__(
self,
transforms: List[Transform],
transform: Transform,
base: Distribution,
):
super().__init__()

codomain_dim = ComposeTransform(transforms).codomain.event_dim
reinterpreted = codomain_dim - len(base.event_shape)
reinterpreted = transform.codomain.event_dim - len(base.event_shape)

if reinterpreted > 0:
base = Independent(base, reinterpreted)

self.transforms = transforms
self.transform = transform
self.base = base
self.reinterpreted = max(-reinterpreted, 0)

def __repr__(self) -> str:
lines = [f'({i + 1}): {t}' for i, t in enumerate(self.transforms)]
lines.append(f'(base): {self.base}')
lines = [
f'(transform): {self.transform}',
f'(base): {self.base}',
]
lines = indent('\n'.join(lines), ' ')

return self.__class__.__name__ + '(\n' + lines + '\n)'
@@ -74,53 +90,31 @@ def batch_shape(self) -> Size:

@property
def event_shape(self) -> Size:
shape = self.base.event_shape

for t in reversed(self.transforms):
shape = t.inverse_shape(shape)

return shape
return self.transform.inverse_shape(self.base.event_shape)

def expand(self, batch_shape: Size, new: Distribution = None):
new = self._get_checked_instance(NormalizingFlow, new)
new.transforms = self.transforms
new.transform = self.transform
new.base = self.base.expand(batch_shape)
new.reinterpreted = self.reinterpreted

Distribution.__init__(new, batch_shape=batch_shape, validate_args=False)
Distribution.__init__(new, validate_args=False)

return new

def log_prob(self, x: Tensor) -> Tensor:
acc = 0
event_dim = len(self.event_shape)

for t in self.transforms:
x, ladj = t.call_and_ladj(x)
acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
event_dim += t.codomain.event_dim - t.domain.event_dim
z, ladj = self.transform.call_and_ladj(x)
ladj = _sum_rightmost(ladj, self.reinterpreted)

return self.base.log_prob(x) + acc

@property
def has_rsample(self) -> bool:
return self.base.has_rsample
return self.base.log_prob(z) + ladj

def rsample(self, shape: Size = ()):
x = self.base.rsample(shape)

for t in reversed(self.transforms):
x = t.inv(x)

return x

def sample(self, shape: Size = ()):
with torch.no_grad():
x = self.base.sample(shape)

for t in reversed(self.transforms):
x = t.inv(x)
if self.base.has_rsample:
z = self.base.rsample(shape)
else:
z = self.base.sample(shape)

return x
return self.transform.inv(z)


class Joint(Distribution):
@@ -334,7 +328,7 @@ class TransformedUniform(NormalizingFlow):
"""

def __init__(self, f: Transform, lower: Tensor, upper: Tensor):
super().__init__([f], Uniform(*map(f, map(torch.as_tensor, (lower, upper)))))
super().__init__(f, Uniform(*map(f, map(torch.as_tensor, (lower, upper)))))

def expand(self, batch_shape: Size, new: Distribution = None) -> Distribution:
new = self._get_checked_instance(TransformedUniform, new)
@@ -372,7 +366,7 @@ def __init__(
):
super().__init__(batch_shape=base.batch_shape)

assert len(base.event_shape) < 1, "'base' has to be univariate"
assert not base.event_shape, "'base' has to be univariate"

self.base = base
self.uniform = Uniform(base.cdf(lower), base.cdf(upper))
@@ -430,7 +424,7 @@ def __init__(
):
super().__init__(batch_shape=base.batch_shape)

assert len(base.event_shape) < 1, "'base' has to be univariate"
assert not base.event_shape, "'base' has to be univariate"

self.base = base
self.n = n
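As a usage sketch for the refactored NormalizingFlow above (not part of the diff): it assumes zuko.transforms is imported so that the call_and_ladj helper added later in this commit is attached to torch's Transform, and it checks log_prob by hand against the change-of-variables formula from the docstring.

import torch
import zuko.transforms  # assumed to register Transform.call_and_ladj, which log_prob relies on

from torch.distributions import ExpTransform, Gamma
from zuko.distributions import NormalizingFlow

d = NormalizingFlow(ExpTransform(), Gamma(2.0, 1.0))  # f = exp, p(Z) = Gamma(2, 1), so X = log(Z)

x = d.sample((5,))     # z ~ Gamma(2, 1) mapped through f^{-1} = log
log_p = d.log_prob(x)  # p(Z = exp(x)) scaled by |d exp(x) / dx|, in log space

# Manual check: for f = exp, log|det df/dx| = x.
manual = Gamma(2.0, 1.0).log_prob(torch.exp(x)) + x
assert torch.allclose(log_p, manual, atol=1e-6)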
5 changes: 3 additions & 2 deletions zuko/flows.py
@@ -22,6 +22,7 @@
from functools import partial
from math import ceil
from torch import Tensor, LongTensor, Size
from torch.distributions import *
from typing import *

from .distributions import *
@@ -89,14 +90,14 @@ def forward(self, y: Tensor = None) -> NormalizingFlow:
A normalizing flow :math:`p(X | y)`.
"""

transforms = [t(y) for t in self.transforms]
transform = ComposedTransform(*(t(y) for t in self.transforms))

if y is None:
base = self.base(y)
else:
base = self.base(y).expand(y.shape[:-1])

return NormalizingFlow(transforms, base)
return NormalizingFlow(transform, base)


class Unconditional(nn.Module):
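A short sketch of what the flows module returns after this change: forward() now hands a single ComposedTransform to NormalizingFlow. The MAF flow and its (features, context) signature are used here only for illustration and are an assumption, not something shown in this diff.

import torch
from zuko.flows import MAF  # assumed flow class; any zuko flow built on the module above would do

flow = MAF(features=3, context=2)

y = torch.randn(2)
d = flow(y)            # the conditioned transforms are composed into one ComposedTransform
x = d.sample((4,))     # (4, 3) samples from p(X | y)
log_p = d.log_prob(x)  # (4,) log-densities
print(d)               # repr shows a single (transform) entry followed by (base)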
2 changes: 1 addition & 1 deletion zuko/nn.py
@@ -39,7 +39,7 @@ def forward(self, x: Tensor) -> Tensor:
class MLP(nn.Sequential):
r"""Creates a multi-layer perceptron (MLP).
Also known as fully connected feedforward network, an MLP is a series of
Also known as fully connected feedforward network, an MLP is a sequence of
non-linear parametric functions
.. math:: h_{i + 1} = a_{i + 1}(h_i W_{i + 1}^T + b_{i + 1}),
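To make the corrected sentence concrete, a plain torch.nn sketch of the recurrence h_{i + 1} = a_{i + 1}(h_i W_{i + 1}^T + b_{i + 1}); the layer sizes are illustrative and unrelated to zuko's MLP defaults.

import torch
import torch.nn as nn

mlp = nn.Sequential(
    nn.Linear(3, 64), nn.ReLU(),   # h_1 = a_1(x W_1^T + b_1)
    nn.Linear(64, 64), nn.ReLU(),  # h_2 = a_2(h_1 W_2^T + b_2)
    nn.Linear(64, 1),              # final affine map, no activation
)

print(mlp(torch.randn(10, 3)).shape)  # torch.Size([10, 1])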
111 changes: 110 additions & 1 deletion zuko/transforms.py
@@ -1,12 +1,30 @@
r"""Parameterizable transformations."""

__all__ = [
'ComposedTransform',
'IdentityTransform',
'CosTransform',
'SinTransform',
'SoftclipTransform',
'MonotonicAffineTransform',
'MonotonicRQSTransform',
'MonotonicTransform',
'UnconstrainedMonotonicTransform',
'SOSPolynomialTransform',
'FFJTransform',
'AutoregressiveTransform',
'PermutationTransform',
]

import math
import torch
import torch.nn.functional as F

from torch import Tensor, LongTensor
from textwrap import indent
from torch import Tensor, LongTensor, Size
from torch.distributions import *
from torch.distributions import constraints
from torch.distributions.utils import _sum_rightmost
from typing import *

from .utils import bisection, broadcast, gauss_legendre, odeint
@@ -28,6 +46,97 @@ def _call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
Transform.call_and_ladj = _call_and_ladj


class ComposedTransform(Transform):
r"""Creates a transformation :math:`f(x) = f_n \circ \dots \circ f_0(x)`.
Arguments:
transforms: A sequence of transformations :math:`f_i`.
"""

def __init__(self, *transforms: Transform, **kwargs):
super().__init__(**kwargs)

assert transforms, "'transforms' cannot be empty"

event_dim = 0

for t in reversed(transforms):
event_dim = t.domain.event_dim + max(event_dim - t.codomain.event_dim, 0)

self.domain_dim = event_dim

for t in transforms:
event_dim += t.codomain.event_dim - t.domain.event_dim

self.codomain_dim = event_dim
self.transforms = transforms

def __repr__(self) -> str:
lines = [f'({i}): {t}' for i, t in enumerate(self.transforms)]
lines = indent('\n'.join(lines), ' ')

return f'{self.__class__.__name__}(\n' + lines + '\n)'

@property
def domain(self) -> constraints.Constraint:
domain = self.transforms[0].domain
reinterpreted = self.domain_dim - domain.event_dim

if reinterpreted > 0:
return constraints.independent(domain, reinterpreted)
else:
return domain

@property
def codomain(self) -> constraints.Constraint:
codomain = self.transforms[-1].codomain
reinterpreted = self.codomain_dim - codomain.event_dim

if reinterpreted > 0:
return constraints.independent(codomain, reinterpreted)
else:
return codomain

@property
def bijective(self) -> bool:
return all(t.bijective for t in self.transforms)

def _call(self, x: Tensor) -> Tensor:
for t in self.transforms:
x = t(x)
return x

def _inverse(self, y: Tensor) -> Tensor:
for t in reversed(self.transforms):
y = t.inv(y)
return y

def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
_, ladj = self.call_and_ladj(x)
return ladj

def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
event_dim = self.domain_dim
acc = 0

for t in self.transforms:
x, ladj = t.call_and_ladj(x)
acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
event_dim += t.codomain.event_dim - t.domain.event_dim

return x, acc

def forward_shape(self, shape: Size) -> Size:
for t in self.transforms:
shape = t.forward_shape(shape)
return shape

def inverse_shape(self, shape: Size) -> Size:
for t in reversed(self.transforms):
shape = t.inverse_shape(shape)
return shape


class IdentityTransform(Transform):
r"""Creates a transformation :math:`f(x) = x`."""

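A small sketch of the new ComposedTransform (not part of the diff), composing two stock torch transforms and exercising the composed call, the inverse, and the call_and_ladj method added in this commit:

import math
import torch

from torch.distributions import AffineTransform, ExpTransform
from zuko.transforms import ComposedTransform

f = ComposedTransform(AffineTransform(1.0, 2.0), ExpTransform())  # f(x) = exp(2x + 1)

x = torch.randn(5)
y = f(x)                      # transforms applied left to right
z = f.inv(y)                  # inverses applied right to left
w, ladj = f.call_and_ladj(x)  # forward pass and accumulated log|det| in one call

assert torch.allclose(x, z, atol=1e-6)
assert torch.allclose(ladj, math.log(2.0) + 2.0 * x + 1.0, atol=1e-6)  # log 2 from the affine map, 2x + 1 from exp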
3 changes: 1 addition & 2 deletions zuko/utils.py
@@ -6,7 +6,6 @@

import numpy as np
import torch
import torch.nn as nn

from functools import lru_cache
from torch import Tensor
@@ -22,7 +21,7 @@ def bisection(
phi: Iterable[Tensor] = (),
) -> Tensor:
r"""Applies the bisection method to find :math:`x` between the bounds :math:`a`
an :math:`b` such that :math:`f_\phi(x)` is close to :math:`y`.
and :math:`b` such that :math:`f_\phi(x)` is close to :math:`y`.
Gradients are propagated through :math:`y` and :math:`\phi` via implicit
differentiation.
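For the bisection docstring fixed above, a self-contained sketch of the basic method under the assumption that f is monotonically increasing; the implicit differentiation mentioned in the docstring belongs to zuko's implementation and is not reproduced here.

import torch

def bisection_sketch(f, y, a, b, n=32):
    # Halve [a, b] around the x with f(x) close to y, assuming f is increasing.
    a, b = torch.broadcast_tensors(torch.as_tensor(a), torch.as_tensor(b))
    for _ in range(n):
        x = (a + b) / 2
        mask = f(x) < y
        a = torch.where(mask, x, a)
        b = torch.where(mask, b, x)
    return (a + b) / 2

x = bisection_sketch(torch.sinh, torch.tensor(1.0), -10.0, 10.0)
print(x, torch.asinh(torch.tensor(1.0)))  # both close to 0.8814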
