Skip to content

Commit

Permalink
✨ Add Glow-like multi-scale flow
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Dec 31, 2022
1 parent 3e9c68d commit 0729f25
Show file tree
Hide file tree
Showing 6 changed files with 683 additions and 18 deletions.
48 changes: 48 additions & 0 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,51 @@ def test_autoregressive_transforms():
assert (torch.triu(J, diagonal=1) == 0).all(), t
assert (torch.tril(J[:4, :4], diagonal=-1) == 0).all(), t
assert (torch.tril(J[4:, 4:], diagonal=-1) == 0).all(), t


def test_Glow(tmp_path):
flow = Glow((3, 32, 32), context=[5, 0, 5])

# Evaluation of log_prob
x, y = randn(8, 3, 32, 32), [randn(5, 16, 16), None, randn(8, 5, 4, 4)]
log_p = flow(y).log_prob(x)

assert log_p.shape == (8,)
assert log_p.requires_grad

flow.zero_grad(set_to_none=True)
loss = -log_p.mean()
loss.backward()

for p in flow.parameters():
assert p.grad is not None

# Sampling
x = flow(y).sample()

assert x.shape == (8, 3, 32, 32)

# Reparameterization trick
x = flow(y).rsample()

flow.zero_grad(set_to_none=True)
loss = x.square().sum().sqrt()
loss.backward()

for p in flow.parameters():
assert p.grad is not None

# Saving
torch.save(flow, tmp_path / 'flow.pth')

# Loading
flow_bis = torch.load(tmp_path / 'flow.pth')

x, y = randn(3, 32, 32), [randn(5, 16, 16), None, randn(5, 4, 4)]

seed = torch.seed()
log_p = flow(y).log_prob(x)
torch.manual_seed(seed)
log_p_bis = flow_bis(y).log_prob(x)

assert torch.allclose(log_p, log_p_bis)
17 changes: 17 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,20 @@ def test_MonotonicMLP():
J = torch.autograd.functional.jacobian(net, x)

assert (J >= 0).all()


def test_FCN():
net = FCN(3, 5)

# Non-batched
x = randn(3, 64, 64)
y = net(x)

assert y.shape == (5, 64, 64)
assert y.requires_grad

# Batched
x = randn(8, 3, 32, 32)
y = net(x)

assert y.shape == (8, 5, 32, 32)
67 changes: 57 additions & 10 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,49 @@ def test_univariate_transforms():
assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t


def test_multivariate_transforms():
ts = [
LULinearTransform(randn(3, 3), dim=-2),
PermutationTransform(torch.randperm(3), dim=-2),
PixelShuffleTransform(dim=-2),
]

for t in ts:
# Shapes
x = randn(256, 3, 8)
y = t(x)

assert t.forward_shape(x.shape) == y.shape, t
assert t.inverse_shape(y.shape) == x.shape, t

# Inverse
z = t.inv(y)

assert x.shape == z.shape, t
assert torch.allclose(x, z, atol=1e-4), t

# Jacobian
x = randn(3, 8)
y = t(x)

jacobian = torch.autograd.functional.jacobian(t, x)
jacobian = jacobian.reshape(3 * 8, 3 * 8)

_, ladj = torch.slogdet(jacobian)

assert torch.allclose(t.log_abs_det_jacobian(x, y), ladj, atol=1e-4), t

# Inverse Jacobian
z = t.inv(y)

jacobian = torch.autograd.functional.jacobian(t.inv, y)
jacobian = jacobian.reshape(3 * 8, 3 * 8)

_, ladj = torch.slogdet(jacobian)

assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t


def test_FFJTransform():
a = torch.randn(3)
f = lambda x, t: a * x
Expand All @@ -80,20 +123,24 @@ def test_FFJTransform():
assert ladj.shape == x.shape[:-1]


def test_PermutationTransform():
t = PermutationTransform(torch.randperm(8))
def test_DropTransform():
dist = Normal(randn(3), abs(randn(3)) + 1)
t = DropTransform(dist)

x = torch.randn(256, 8)
# Call
x = randn(256, 5)
y = t(x)

assert x.shape == y.shape

match = x[:, :, None] == y[:, None, :]

assert (match.sum(dim=-1) == 1).all()
assert (match.sum(dim=-2) == 1).all()
assert t.forward_shape(x.shape) == y.shape
assert t.inverse_shape(y.shape) == x.shape

# Inverse
z = t.inv(y)

assert x.shape == z.shape
assert (x == z).all()
assert not torch.allclose(x, z)

# Jacobian
ladj = t.log_abs_det_jacobian(x, y)

assert ladj.shape == (256,)
192 changes: 192 additions & 0 deletions zuko/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
'NAF',
'FreeFormJacobianTransform',
'CNF',
'ConvCouplingTransform',
'Glow',
]

import abc
Expand Down Expand Up @@ -753,3 +755,193 @@ def __init__(
)

super().__init__(transforms, base)


class ConvCouplingTransform(TransformModule):
r"""Creates a convolution coupling transformation.
Arguments:
channels: The number of channels.
context: The number of context channels.
spatial: The number of spatial dimensions.
univariate: The univariate transformation constructor.
shapes: The shapes of the univariate transformation parameters.
kwargs: Keyword arguments passed to :class:`zuko.nn.FCN`.
"""

def __init__(
self,
channels: int,
context: int = 0,
spatial: int = 2,
univariate: Callable[..., Transform] = MonotonicAffineTransform,
shapes: List[Size] = [(), ()],
**kwargs,
):
super().__init__()

self.d = channels // 2
self.dim = -(spatial + 1)

# Univariate transformation
self.univariate = univariate
self.shapes = list(map(Size, shapes))
self.sizes = [s.numel() for s in self.shapes]

# Hyper network
kwargs.setdefault('activation', nn.ELU)
kwargs.setdefault('normalize', True)

self.hyper = FCN(
in_channels=self.d + context,
out_channels=(channels - self.d) * sum(self.sizes),
spatial=spatial,
**kwargs,
)

def extra_repr(self) -> str:
base = self.univariate(*map(torch.randn, self.shapes))

return f'(base): {base}'

def meta(self, y: Tensor, x: Tensor) -> Transform:
if y is not None:
x = torch.cat(broadcast(x, y, ignore=abs(self.dim)), dim=self.dim)

total = sum(self.sizes)

phi = self.hyper(x)
phi = phi.unflatten(self.dim, (phi.shape[self.dim] // total, total))
phi = phi.movedim(self.dim, -1)
phi = phi.split(self.sizes, -1)
phi = (p.unflatten(-1, s + (1,)) for p, s in zip(phi, self.shapes))
phi = (p.squeeze(-1) for p in phi)

return self.univariate(*phi)

def forward(self, y: Tensor = None) -> Transform:
return CouplingTransform(partial(self.meta, y), self.d, self.dim)


class Glow(DistributionModule):
r"""Creates a Glow-like multi-scale flow.
References:
| Glow: Generative Flow with Invertible 1x1 Convolutions (Kingma et al., 2018)
| https://arxiv.org/abs/1807.03039
Arguments:
shape: The shape of a sample.
context: The number of context channels at each scale.
transforms: The number of coupling transformations at each scale.
kwargs: Keyword arguments passed to :class:`ConvCouplingTransform`.
"""

def __init__(
self,
shape: Size,
context: Union[int, List[int]] = 0,
transforms: List[int] = [8, 8, 8],
**kwargs,
):
super().__init__()

channels, *space = shape
spatial = len(space)
dim = -len(shape)
scales = len(transforms)

assert all(s % 2**scales == 0 for s in space), (
f"'shape' cannot be downscaled {scales} times"
)

if isinstance(context, int):
context = [context] * len(transforms)

self.flows = nn.ModuleList()
self.bases = nn.ModuleList()

for i, K in enumerate(transforms):
flow = []
flow.append(Unconditional(PixelShuffleTransform, dim=dim))

channels = channels * 2**spatial
space = [s // 2 for s in space]

for _ in range(K):
flow.extend([
Unconditional(
PermutationTransform,
torch.randperm(channels),
dim=dim,
buffer=True,
),
Unconditional(
LULinearTransform,
torch.eye(channels),
dim=dim,
),
ConvCouplingTransform(
channels=channels,
context=context[i],
spatial=spatial,
**kwargs,
),
])

if i < len(transforms) - 1:
drop = channels // 2
else:
drop = channels

self.flows.append(nn.ModuleList(flow))
self.bases.append(
Unconditional(
DiagNormal,
torch.zeros(drop, *space),
torch.ones(drop, *space),
ndims=spatial + 1,
buffer=True,
)
)

channels = channels - drop

def forward(self, y: Iterable[Tensor] = None) -> NormalizingFlow:
r"""
Arguments:
y: A sequence of contexts :math:`y_i`. There should be one context
per scale, but a context can be :py:`None`.
Returns:
A multi-scale flow :math:`p(X | y)`.
"""

if y is None:
y = [None] * len(self.flows)

# Transforms
transforms = []
context_shapes = []

for flow, base, y_i in zip(self.flows, self.bases, y):
for t in flow:
transforms.append(t(y_i))

transforms.append(DropTransform(base(y_i)))

if y_i is not None:
context_shapes.append(y_i.shape)

transform = ComposedTransform(*transforms[:-1])

# Base
base = transforms[-1].dist
dim = -len(base.event_shape)

batch_shapes = (shape[:dim] for shape in context_shapes)
batch_shape = torch.broadcast_shapes(*batch_shapes)

base = base.expand(batch_shape)

return NormalizingFlow(transform, base)
Loading

0 comments on commit 0729f25

Please sign in to comment.