✨ New ComposedTransform transformation

francois-rozet committed Dec 30, 2022
1 parent a755a69 commit 997b9e6

Showing 7 changed files with 159 additions and 48 deletions.
1 change: 1 addition & 0 deletions tests/test_distributions.py
@@ -3,6 +3,7 @@
 import pytest
 import torch
 
+from torch.distributions import *
 from zuko.distributions import *
 
 
1 change: 1 addition & 0 deletions tests/test_transforms.py
@@ -4,6 +4,7 @@
 import torch
 
 from torch import randn
+from torch.distributions import *
 from zuko.transforms import *
 
 
90 changes: 45 additions & 45 deletions zuko/distributions.py
@@ -1,5 +1,19 @@
 r"""Parameterizable probability distributions."""
 
+__all__ = [
+    'NormalizingFlow',
+    'Joint',
+    'GeneralizedNormal',
+    'DiagNormal',
+    'BoxUniform',
+    'TransformedUniform',
+    'Truncated',
+    'Sort',
+    'TopK',
+    'Minimum',
+    'Maximum',
+]
+
 import math
 import torch
 
@@ -10,15 +24,17 @@
 from torch.distributions.utils import _sum_rightmost
 from typing import *
 
+from .transforms import ComposedTransform
+
 
 Distribution._validate_args = False
 Distribution.arg_constraints = {}
 
 
 class NormalizingFlow(Distribution):
     r"""Creates a normalizing flow for a random variable :math:`X` towards a base
-    distribution :math:`p(Z)` through a series of :math:`n` invertible and differentiable
-    transformations :math:`f_1, f_2, \dots, f_n`.
+    distribution :math:`p(Z)` through a sequence of :math:`n` invertible and
+    differentiable transformations :math:`f_1, f_2, \dots, f_n`.
 
     The density of a realization :math:`x` is given by the change of variables
@@ -36,7 +52,7 @@ class NormalizingFlow(Distribution):
     | https://arxiv.org/abs/1912.02762
 
     Arguments:
-        transforms: A list of transformations :math:`f_i`.
+        transforms: A sequence of transformations :math:`f_i`.
         base: A base distribution :math:`p(Z)`.
 
     Example:
@@ -45,25 +61,29 @@ class NormalizingFlow(Distribution):
         tensor(1.1316)
     """
 
+    has_rsample = True
+
     def __init__(
         self,
-        transforms: List[Transform],
+        transforms: Iterable[Transform],
         base: Distribution,
     ):
         super().__init__()
 
-        codomain_dim = ComposeTransform(transforms).codomain.event_dim
-        reinterpreted = codomain_dim - len(base.event_shape)
+        transform = ComposedTransform(*transforms)
+        reinterpreted = transform.codomain_dim - len(base.event_shape)
 
         if reinterpreted > 0:
             base = Independent(base, reinterpreted)
 
-        self.transforms = transforms
+        self.transform = transform
         self.base = base
 
     def __repr__(self) -> str:
-        lines = [f'({i + 1}): {t}' for i, t in enumerate(self.transforms)]
-        lines.append(f'(base): {self.base}')
+        lines = [
+            f'(transform): {self.transform}',
+            f'(base): {self.base}',
+        ]
         lines = indent('\n'.join(lines), '  ')
 
         return self.__class__.__name__ + '(\n' + lines + '\n)'
@@ -74,53 +94,33 @@ def batch_shape(self) -> Size:
 
     @property
     def event_shape(self) -> Size:
-        shape = self.base.event_shape
-
-        for t in reversed(self.transforms):
-            shape = t.inverse_shape(shape)
-
-        return shape
+        return self.transform.inverse_shape(self.base.event_shape)
 
     def expand(self, batch_shape: Size, new: Distribution = None):
         new = self._get_checked_instance(NormalizingFlow, new)
-        new.transforms = self.transforms
+        new.transform = self.transform
         new.base = self.base.expand(batch_shape)
 
-        Distribution.__init__(new, batch_shape=batch_shape, validate_args=False)
+        Distribution.__init__(new, validate_args=False)
 
         return new
 
     def log_prob(self, x: Tensor) -> Tensor:
-        acc = 0
-        event_dim = len(self.event_shape)
-
-        for t in self.transforms:
-            x, ladj = t.call_and_ladj(x)
-            acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
-            event_dim += t.codomain.event_dim - t.domain.event_dim
+        z, ladj = self.transform.call_and_ladj(x)
+        ladj = _sum_rightmost(
+            ladj,
+            len(self.base.event_shape) - self.transform.codomain_dim,
+        )
 
-        return self.base.log_prob(x) + acc
-
-    @property
-    def has_rsample(self) -> bool:
-        return self.base.has_rsample
+        return self.base.log_prob(z) + ladj
 
     def rsample(self, shape: Size = ()):
-        x = self.base.rsample(shape)
-
-        for t in reversed(self.transforms):
-            x = t.inv(x)
-
-        return x
-
-    def sample(self, shape: Size = ()):
-        with torch.no_grad():
-            x = self.base.sample(shape)
-
-            for t in reversed(self.transforms):
-                x = t.inv(x)
+        if self.base.has_rsample:
+            z = self.base.rsample(shape)
+        else:
+            z = self.base.sample(shape)
 
-            return x
+        return self.transform.inv(z)
 
 
 class Joint(Distribution):
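
A note on the refactor above: NormalizingFlow now delegates all shape and Jacobian bookkeeping to a single ComposedTransform, so log_prob is one application of the change of variables, log p(x) = log p(f(x)) + log|det J_f(x)|, with the log-absolute-determinant summed over any trailing event dimensions of the base that the transform's codomain does not already cover. Below is a minimal usage sketch, not part of the commit; it uses only torch.distributions transforms and hypothetical shapes:

    import torch

    from torch.distributions import AffineTransform, ExpTransform, Normal
    from zuko.distributions import NormalizingFlow

    # f(x) = 2 * log(x) maps X (positive reals) onto Z (reals),
    # where the base p(Z) is a standard Normal.
    flow = NormalizingFlow(
        transforms=[ExpTransform().inv, AffineTransform(0.0, 2.0)],
        base=Normal(torch.zeros(()), torch.ones(())),
    )

    x = torch.rand(5) + 0.1
    log_p = flow.log_prob(x)   # log p(f(x)) + log|det J_f(x)|, shape (5,)
    x_new = flow.sample((3,))  # z ~ p(Z), then x = f^{-1}(z), shape (3,)

Since the class now sets has_rsample = True, rsample() draws through the base's rsample() when available and falls back to sample() otherwise, so gradients flow through the transform in the reparameterized case.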
@@ -372,7 +372,7 @@ def __init__(
     ):
         super().__init__(batch_shape=base.batch_shape)
 
-        assert len(base.event_shape) < 1, "'base' has to be univariate"
+        assert not base.event_shape, "'base' has to be univariate"
 
         self.base = base
         self.uniform = Uniform(base.cdf(lower), base.cdf(upper))
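
The tightened assertion works because event_shape is a torch.Size, a tuple subclass, so it is falsy exactly when the distribution is univariate. An illustrative check, not from the commit:

    import torch
    from torch.distributions import MultivariateNormal, Normal

    assert not Normal(0.0, 1.0).event_shape  # torch.Size([]) is falsy
    assert MultivariateNormal(torch.zeros(2), torch.eye(2)).event_shape  # torch.Size([2]) is truthy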
@@ -430,7 +430,7 @@ def __init__(
     ):
         super().__init__(batch_shape=base.batch_shape)
 
-        assert len(base.event_shape) < 1, "'base' has to be univariate"
+        assert not base.event_shape, "'base' has to be univariate"
 
         self.base = base
         self.n = n
1 change: 1 addition & 0 deletions zuko/flows.py
@@ -22,6 +22,7 @@
 from functools import partial
 from math import ceil
 from torch import Tensor, LongTensor, Size
+from torch.distributions import *
 from typing import *
 
 from .distributions import *
2 changes: 1 addition & 1 deletion zuko/nn.py
@@ -39,7 +39,7 @@ def forward(self, x: Tensor) -> Tensor:
 class MLP(nn.Sequential):
     r"""Creates a multi-layer perceptron (MLP).
 
-    Also known as fully connected feedforward network, an MLP is a series of
+    Also known as fully connected feedforward network, an MLP is a sequence of
     non-linear parametric functions
 
     .. math:: h_{i + 1} = a_{i + 1}(h_i W_{i + 1}^T + b_{i + 1}),
111 changes: 110 additions & 1 deletion zuko/transforms.py
@@ -1,12 +1,30 @@
 r"""Parameterizable transformations."""
 
+__all__ = [
+    'ComposedTransform',
+    'IdentityTransform',
+    'CosTransform',
+    'SinTransform',
+    'SoftclipTransform',
+    'MonotonicAffineTransform',
+    'MonotonicRQSTransform',
+    'MonotonicTransform',
+    'UnconstrainedMonotonicTransform',
+    'SOSPolynomialTransform',
+    'FFJTransform',
+    'AutoregressiveTransform',
+    'PermutationTransform',
+]
+
 import math
 import torch
 import torch.nn.functional as F
 
-from torch import Tensor, LongTensor
+from textwrap import indent
+from torch import Tensor, LongTensor, Size
 from torch.distributions import *
 from torch.distributions import constraints
+from torch.distributions.utils import _sum_rightmost
 from typing import *
 
 from .utils import bisection, broadcast, gauss_legendre, odeint
@@ -28,6 +46,97 @@ def _call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
 Transform.call_and_ladj = _call_and_ladj
 
 
+class ComposedTransform(Transform):
+    r"""Creates a transformation :math:`f(x) = f_n \circ \dots \circ f_1(x)`.
+
+    Arguments:
+        transforms: A sequence of transformations :math:`f_i`.
+    """
+
+    def __init__(self, *transforms: Transform, **kwargs):
+        super().__init__(**kwargs)
+
+        assert transforms, "'transforms' cannot be empty"
+
+        event_dim = 0
+
+        for t in reversed(transforms):
+            event_dim = t.domain.event_dim + max(event_dim - t.codomain.event_dim, 0)
+
+        self.domain_dim = event_dim
+
+        for t in transforms:
+            event_dim += t.codomain.event_dim - t.domain.event_dim
+
+        self.codomain_dim = event_dim
+        self.transforms = transforms
+
+    def __repr__(self) -> str:
+        lines = [f'({i + 1}): {t}' for i, t in enumerate(self.transforms)]
+        lines = indent('\n'.join(lines), '  ')
+
+        return f'{self.__class__.__name__}(\n' + lines + '\n)'
+
+    @property
+    def domain(self) -> constraints.Constraint:
+        domain = self.transforms[0].domain
+        reinterpreted = self.domain_dim - domain.event_dim
+
+        if reinterpreted > 0:
+            return constraints.independent(domain, reinterpreted)
+        else:
+            return domain
+
+    @property
+    def codomain(self) -> constraints.Constraint:
+        codomain = self.transforms[-1].codomain
+        reinterpreted = self.codomain_dim - codomain.event_dim
+
+        if reinterpreted > 0:
+            return constraints.independent(codomain, reinterpreted)
+        else:
+            return codomain
+
+    @property
+    def bijective(self) -> bool:
+        return all(t.bijective for t in self.transforms)
+
+    def _call(self, x: Tensor) -> Tensor:
+        for t in self.transforms:
+            x = t(x)
+        return x
+
+    def _inverse(self, y: Tensor) -> Tensor:
+        for t in reversed(self.transforms):
+            y = t.inv(y)
+        return y
+
+    def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
+        _, ladj = self.call_and_ladj(x)
+        return ladj
+
+    def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+        event_dim = self.domain_dim
+        acc = 0
+
+        for t in self.transforms:
+            x, ladj = t.call_and_ladj(x)
+            acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
+            event_dim += t.codomain.event_dim - t.domain.event_dim
+
+        return x, acc
+
+    def forward_shape(self, shape: Size) -> Size:
+        for t in self.transforms:
+            shape = t.forward_shape(shape)
+        return shape
+
+    def inverse_shape(self, shape: Size) -> Size:
+        for t in reversed(self.transforms):
+            shape = t.inverse_shape(shape)
+        return shape
+
+
 class IdentityTransform(Transform):
     r"""Creates a transformation :math:`f(x) = x`."""
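
Two things to note about the new class. It aggregates the event dimensions of the whole chain into domain_dim and codomain_dim: the backward loop computes how many trailing dimensions the input must have so that every downstream step receives enough event dimensions, and the forward loop then accumulates the net event_dim change. And its call_and_ladj returns the output together with the log-absolute-determinant of the Jacobian, which NormalizingFlow.log_prob exploits. A short sketch of the bookkeeping, using only torch.distributions transforms; it is illustrative and not from the commit:

    import torch
    from torch.distributions import AffineTransform, StickBreakingTransform
    from zuko.transforms import ComposedTransform

    f = ComposedTransform(
        AffineTransform(0.0, 2.0),  # elementwise, event_dim 0
        StickBreakingTransform(),   # vector -> simplex point, event_dim 1
    )

    # 1 1: the scalar affine is reinterpreted over the event axis.
    print(f.domain_dim, f.codomain_dim)

    x = torch.randn(5, 3)
    y, ladj = f.call_and_ladj(x)     # y.shape == (5, 4), ladj.shape == (5,)
    print(f.forward_shape(x.shape))  # torch.Size([5, 4])
    print(f.inv(y))                  # approximately recovers x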
1 change: 0 additions & 1 deletion zuko/utils.py
@@ -6,7 +6,6 @@
 
 import numpy as np
 import torch
-import torch.nn as nn
 
 from functools import lru_cache
 from torch import Tensor
