✨ New features
* Add call_and_ladj method to transformations to improve flow efficiency
* Add continuous normalizing flow (CNF)
* Add free-form Jacobian (FFJ) transformation
* Add odeint solver
* Make bisection auto-differentiable
francois-rozet committed Dec 22, 2022
1 parent 2dd249f commit 2d5a37f
Showing 9 changed files with 818 additions and 194 deletions.
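
A minimal usage sketch of the new features, assuming the import paths zuko.flows.CNF and zuko.utils.odeint (the exact paths are not confirmed by this diff); the calls mirror those exercised in the updated tests below.

import torch
from zuko.flows import CNF      # assumed import path
from zuko.utils import odeint   # assumed import path

# Continuous normalizing flow over 3 features with a 5-dimensional context,
# mirroring CNF(3, 5) in tests/test_flows.py below.
flow = CNF(3, 5)

x, y = torch.randn(256, 3), torch.randn(5)

log_p = flow(y).log_prob(x)      # density evaluation, shape (256,)
(-log_p.mean()).backward()       # gradients reach the flow parameters

samples = flow(y).sample((32,))  # sampling, shape (32, 3)

# The odeint solver integrates dx/dt = f(x, t) from t0 to t1; extra
# parameters are passed through phi, as in test_odeint below.
alpha = torch.tensor(1.0, requires_grad=True)
f = lambda x, t: -alpha * x
x1 = odeint(f, torch.randn(256, 1), torch.tensor(0.0), torch.tensor(3.0), phi=(alpha,))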
setup.py (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@

setuptools.setup(
name='zuko',
version='0.0.6',
version='0.0.7',
packages=setuptools.find_packages(),
description='Normalizing flows in PyTorch',
keywords=[
tests/test_flows.py (20 changes: 17 additions & 3 deletions)
@@ -14,6 +14,7 @@ def test_flows(tmp_path):
SOSPF(3, 5),
NAF(3, 5),
NAF(3, 5, unconstrained=True),
CNF(3, 5),
]

for flow in flows:
@@ -24,16 +25,27 @@
assert log_p.shape == (256,), flow
assert log_p.requires_grad, flow

flow.zero_grad(set_to_none=True)
loss = -log_p.mean()
loss.backward()

for p in flow.parameters():
assert hasattr(p, 'grad'), flow
assert p.grad is not None, flow

# Sampling
z = flow(y).sample((32,))
x = flow(y).sample((32,))

assert z.shape == (32, 3), flow
assert x.shape == (32, 3), flow

# Reparameterization trick
x = flow(y).rsample()

flow.zero_grad(set_to_none=True)
loss = x.square().sum().sqrt()
loss.backward()

for p in flow.parameters():
assert p.grad is not None, flow

# Invertibility
x, y = randn(256, 3), randn(256, 5)
@@ -58,7 +70,9 @@

x, y = randn(3), randn(5)

seed = torch.seed()
log_p = flow(y).log_prob(x)
torch.manual_seed(seed)
log_p_bis = flow_bis(y).log_prob(x)

assert torch.allclose(log_p, log_p_bis), flow
tests/test_transforms.py (38 changes: 36 additions & 2 deletions)
@@ -21,15 +21,17 @@ def test_univariate_transforms():
]

for t in ts:
# Call
if hasattr(t.domain, 'lower_bound'):
x = torch.linspace(t.domain.lower_bound, t.domain.upper_bound, 256)
x = torch.linspace(t.domain.lower_bound + 1e-2, t.domain.upper_bound - 1e-2, 256)
else:
x = torch.linspace(-5.0, 5.0, 256)

y = t(x)

assert x.shape == y.shape, t

# Inverse
z = t.inv(y)

assert torch.allclose(x, z, atol=1e-4), t
@@ -42,7 +44,39 @@ def test_univariate_transforms():

ladj = torch.diag(J).abs().log()

assert torch.allclose(ladj, t.log_abs_det_jacobian(x, y), atol=1e-4), t
assert torch.allclose(t.log_abs_det_jacobian(x, y), ladj, atol=1e-4), t

# Inverse Jacobian
J = torch.autograd.functional.jacobian(t.inv, y)

assert (torch.triu(J, diagonal=1) == 0).all(), t
assert (torch.tril(J, diagonal=-1) == 0).all(), t

ladj = torch.diag(J).abs().log()

assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t


def test_FFJTransform():
a = torch.randn(3)
f = lambda x, t: a * x
t = FFJTransform(f, time=torch.tensor(1.0))

# Call
x = randn(256, 3)
y = t(x)

assert x.shape == y.shape

# Inverse
z = t.inv(y)

assert torch.allclose(x, z, atol=1e-4)

# Jacobian
ladj = t.log_abs_det_jacobian(x, y)

assert ladj.shape == x.shape[:-1]


def test_PermutationTransform():
tests/test_utils.py (41 changes: 31 additions & 10 deletions)
@@ -10,14 +10,15 @@

def test_bisection():
f = torch.cos
y = torch.tensor(0.0)
a = torch.rand(256, 1) + 2.0
b = torch.rand(16)

x = bisection(f, a, b, n=18)
x = bisection(f, y, a, b, n=18)

assert x.shape == (256, 16)
assert torch.allclose(x, torch.tensor(math.pi / 2), atol=1e-4)
assert torch.allclose(f(x), torch.tensor(0.0), atol=1e-4)
assert torch.allclose(f(x), y, atol=1e-4)


def test_broadcast():
@@ -59,17 +60,37 @@ def test_gauss_legendre():
# Polynomial
f = lambda x: x**5 - x**2
F = lambda x: x**6 / 6 - x**3 / 3
a, b = randn(2, 256)
a, b = randn(2, 256, requires_grad=True)

area = gauss_legendre(f, a, b, n=3)

assert torch.allclose(F(b) - F(a), area, atol=1e-4)
assert torch.allclose(area, F(b) - F(a), atol=1e-4)

# Gradients
grad_a, grad_b = torch.autograd.functional.jacobian(
lambda a, b: gauss_legendre(f, a, b).sum(),
(a, b),
)
grad_a, grad_b = torch.autograd.grad(area.sum(), (a, b))

assert torch.allclose(-f(a), grad_a)
assert torch.allclose(f(b), grad_b)
assert torch.allclose(grad_a, -f(a), atol=1e-4)
assert torch.allclose(grad_b, f(b), atol=1e-4)


def test_odeint():
# Linear
alpha = torch.tensor(1.0, requires_grad=True)
t = torch.tensor(3.0, requires_grad=True)

f = lambda x, t: -alpha * x
F = lambda x, t: x * (-alpha * t).exp()

x0 = randn(256, 1, requires_grad=True)
xt = odeint(f, x0, torch.zeros_like(t), t, phi=(alpha,))

assert xt.shape == x0.shape
assert torch.allclose(xt, F(x0, t), atol=1e-4)

# Gradients
grad_x0, grad_t, grad_alpha = torch.autograd.grad(xt.sum(), (x0, t, alpha))
g_x0, g_t, g_alpha = torch.autograd.grad(F(x0, t).sum(), (x0, t, alpha))

assert torch.allclose(grad_x0, g_x0, atol=1e-4)
assert torch.allclose(grad_t, g_t, atol=1e-4)
assert torch.allclose(grad_alpha, g_alpha, atol=1e-4)
zuko/distributions.py (74 changes: 67 additions & 7 deletions)
@@ -1,4 +1,4 @@
r"""Parametrizable probability distributions."""
r"""Parameterizable probability distributions."""

import math
import torch
@@ -7,14 +7,15 @@
from torch import Tensor, Size
from torch.distributions import *
from torch.distributions import constraints
from torch.distributions.utils import _sum_rightmost
from typing import *


Distribution._validate_args = False
Distribution.arg_constraints = {}


class NormalizingFlow(TransformedDistribution):
class NormalizingFlow(Distribution):
r"""Creates a normalizing flow for a random variable :math:`X` towards a base
distribution :math:`p(Z)` through a series of :math:`n` invertible and differentiable
transformations :math:`f_1, f_2, \dots, f_n`.
@@ -49,18 +50,77 @@ def __init__(
transforms: List[Transform],
base: Distribution,
):
super().__init__(base, [t.inv for t in reversed(transforms)])
super().__init__()

codomain_dim = ComposeTransform(transforms).codomain.event_dim
reinterpreted = codomain_dim - len(base.event_shape)

if reinterpreted > 0:
base = Independent(base, reinterpreted)

self.transforms = transforms
self.base = base

def __repr__(self) -> str:
lines = [f'({i+1}): {t.inv}' for i, t in enumerate(reversed(self.transforms))]
lines.append(f'(base): {self.base_dist}')
lines = [f'({i + 1}): {t}' for i, t in enumerate(self.transforms)]
lines.append(f'(base): {self.base}')
lines = indent('\n'.join(lines), ' ')

return self.__class__.__name__ + '(\n' + lines + '\n)'

def expand(self, batch_shape: Size, new: Distribution = None) -> Distribution:
@property
def batch_shape(self) -> Size:
return self.base.batch_shape

@property
def event_shape(self) -> Size:
shape = self.base.event_shape

for t in reversed(self.transforms):
shape = t.inverse_shape(shape)

return shape

def expand(self, batch_shape: Size, new: Distribution = None):
new = self._get_checked_instance(NormalizingFlow, new)
return super().expand(batch_shape, new)
new.transforms = self.transforms
new.base = self.base.expand(batch_shape)

Distribution.__init__(new, batch_shape=batch_shape, validate_args=False)

return new

def log_prob(self, x: Tensor) -> Tensor:
acc = 0
event_dim = len(self.event_shape)

for t in self.transforms:
x, ladj = t.call_and_ladj(x)
acc = acc + _sum_rightmost(ladj, event_dim - t.domain.event_dim)
event_dim += t.codomain.event_dim - t.domain.event_dim

return self.base.log_prob(x) + acc

@property
def has_rsample(self) -> bool:
return self.base.has_rsample

def rsample(self, shape: Size = ()):
x = self.base.rsample(shape)

for t in reversed(self.transforms):
x = t.inv(x)

return x

def sample(self, shape: Size = ()):
with torch.no_grad():
x = self.base.sample(shape)

for t in reversed(self.transforms):
x = t.inv(x)

return x


class Joint(Distribution):
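
The new NormalizingFlow.log_prob above consumes (y, ladj) pairs from call_and_ladj, the method named in the first feature bullet. A minimal sketch of what such a method could look like on a custom transform; the ExpLADJ class below is a hypothetical illustration, not part of this commit.

import torch
from torch.distributions import Transform, constraints

class ExpLADJ(Transform):
    # Hypothetical elementwise transform that returns the forward value and
    # the log-abs-det-Jacobian in a single pass, as log_prob above expects.
    domain = constraints.real
    codomain = constraints.positive
    bijective = True

    def _call(self, x):
        return x.exp()

    def _inverse(self, y):
        return y.log()

    def log_abs_det_jacobian(self, x, y):
        return x  # log |d exp(x) / dx| = x

    def call_and_ladj(self, x):
        # Computing y and the LADJ together avoids a second forward pass
        # when evaluating log_prob.
        y = x.exp()
        return y, x

For an elementwise bijection the saving is negligible, but for coupling or autoregressive transforms the forward pass and the Jacobian share most of their computation, so returning both at once can avoid redundant work, which is presumably the efficiency gain the commit message refers to.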
(Diffs for the remaining 4 changed files are not shown.)
