Add option to change covariance matrix type for GMM class #50

Draft: wants to merge 3 commits into base: master
zuko/flows/mixture.py: 118 changes (102 additions, 16 deletions)
@@ -7,15 +7,73 @@
import torch
import torch.nn as nn

-from math import prod
-from torch import Tensor
-from torch.distributions import Distribution, MultivariateNormal

# isort: local
from .core import LazyDistribution
from ..distributions import Mixture
from ..nn import MLP
from ..utils import unpack
+from math import prod
+from torch import Tensor
+from torch.distributions import (
+    Distribution,
+    Independent,
+    LowRankMultivariateNormal,
+    MultivariateNormal,
+    Normal,
+)


def _determine_shapes(components, features, covariance_type, tied, cov_rank):
Review comment (Member):

The following pattern would allow reducing code duplication:

leading = 1 if tied else components

if covariance_type == 'full':
    shapes.extend([
        (leading, features),
        (leading, features * (features - 1) // 2),
    ])
elif ...
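
For concreteness, here is one way the full helper could look if the suggested pattern were applied throughout. This is my sketch of the reviewer's idea, not code from this PR's commits:

def _determine_shapes(components, features, covariance_type, tied, cov_rank):
    # tied covariances share one set of parameters across all components
    leading = 1 if tied else components

    shapes = [
        (components,),  # probabilities
        (components, features),  # mean
    ]

    if covariance_type == 'full':
        shapes.extend([
            (leading, features),  # diagonal
            (leading, features * (features - 1) // 2),  # off diagonal
        ])
    elif covariance_type == 'lowrank':
        if cov_rank is None:
            raise ValueError('cov_rank must be specified when covariance_type is lowrank')
        shapes.extend([
            (leading, features),  # diagonal
            (leading, features * cov_rank),  # low-rank factor
        ])
    elif covariance_type == 'diag':
        shapes.append((leading, features))  # diagonal
    elif covariance_type == 'spherical':
        shapes.append((leading, 1))  # single variance
    else:
        raise ValueError(
            f'Invalid covariance type: {covariance_type} '
            '(choose from full, lowrank, diag, or spherical)'
        )

    return shapes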

    shapes = [
        (components,),  # probabilities
        (components, features),  # mean
    ]
    if covariance_type == 'full' and not tied:
        shapes.extend([
            (components, features),  # diagonal
            (components, features * (features - 1) // 2),  # off diagonal
        ])
    elif covariance_type == 'full' and tied:
        shapes.extend([
            (1, features),  # diagonal
            (1, features * (features - 1) // 2),  # off diagonal
        ])
    elif covariance_type == 'lowrank' and not tied:
        if cov_rank is None:
            raise ValueError('cov_rank must be specified when covariance_type is lowrank')
        shapes.extend([
            (components, features),  # diagonal
            (components, features * cov_rank),  # low-rank
        ])
    elif covariance_type == 'lowrank' and tied:
        if cov_rank is None:
            raise ValueError('cov_rank must be specified when covariance_type is lowrank')
        shapes.extend([
            (1, features),  # diagonal
            (1, features * cov_rank),  # low-rank
        ])
    elif covariance_type == 'diag' and not tied:
        shapes.extend([
            (components, features),  # diagonal
        ])
    elif covariance_type == 'diag' and tied:
        shapes.extend([
            (1, features),  # diagonal
        ])
    elif covariance_type == 'spherical' and not tied:
        shapes.extend([
            (components, 1),  # diagonal
        ])
    elif covariance_type == 'spherical' and tied:
        shapes.extend([
            (1, 1),  # diagonal
        ])
    else:
        raise ValueError(
            f'Invalid covariance type: {covariance_type} (choose from full, lowrank, diag, or spherical)'
        )
    return shapes
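
As a quick sanity check of the helper, the example values below are mine, not from the PR:

# lowrank, untied: per-component diagonals plus 4 * 2 factor entries per component
shapes = _determine_shapes(3, 4, 'lowrank', tied=False, cov_rank=2)
assert shapes == [(3,), (3, 4), (3, 4), (3, 8)]

# full, tied: a single shared diagonal and 4 * 3 // 2 = 6 off-diagonal entries
shapes = _determine_shapes(3, 4, 'full', tied=True, cov_rank=None)
assert shapes == [(3,), (3, 4), (1, 4), (1, 6)]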


class GMM(LazyDistribution):
@@ -30,6 +88,18 @@ class GMM(LazyDistribution):
        features: The number of features.
        context: The number of context features.
        components: The number of components :math:`K` in the mixture.
+       covariance_type: String describing the type of covariance parameters to use.
+           Must be one of:
+
+           - 'full': each component has its own full-rank covariance matrix.
+           - 'lowrank': each component has its own low-rank covariance matrix.
+           - 'diag': each component has its own diagonal covariance matrix.
+           - 'spherical': each component has its own single variance.
+
+       tied: Whether to use tied covariance matrices. Tied covariances share the
+           same parameters across components.
+       cov_rank: The rank of the low-rank covariance matrix. Only used when
+           `covariance_type` is 'lowrank'.
        kwargs: Keyword arguments passed to :class:`zuko.nn.MLP`.
    """

@@ -38,17 +108,18 @@ def __init__(
        features: int,
        context: int = 0,
        components: int = 2,
+       covariance_type: str = 'full',
+       tied: bool = False,
+       cov_rank: int = None,
        **kwargs,
    ):
        super().__init__()

-        shapes = [
-            (components,),  # probabilities
-            (components, features),  # mean
-            (components, features),  # diagonal
-            (components, features * (features - 1) // 2),  # off diagonal
-        ]
+        shapes = _determine_shapes(components, features, covariance_type, tied, cov_rank)

+        self.covariance_type = covariance_type
+        self.tied = tied
+        self.cov_rank = cov_rank
        self.shapes = shapes
        self.total = sum(prod(s) for s in shapes)
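
To see what the new options buy in practice, the resulting `self.total` for, say, `features=10, components=4` works out as follows (my arithmetic, using the shapes above):

from math import prod

features, components = 10, 4

# features * (features - 1) // 2 == 45 for features == 10
variants = {
    'full': [(components,), (components, features), (components, features), (components, 45)],
    'diag': [(components,), (components, features), (components, features)],
    'spherical+tied': [(components,), (components, features), (1, 1)],
}

for name, shapes in variants.items():
    print(name, sum(prod(s) for s in shapes))
# full 264, diag 84, spherical+tied 45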

@@ -64,10 +135,25 @@ def forward(self, c: Tensor = None) -> Distribution:
        phi = self.hyper(c)
        phi = unpack(phi, self.shapes)

-        logits, loc, diag, tril = phi
-
-        scale = torch.diag_embed(diag.exp() + 1e-5)
-        mask = torch.tril(torch.ones_like(scale, dtype=bool), diagonal=-1)
-        scale = torch.masked_scatter(scale, mask, tril)
-
-        return Mixture(MultivariateNormal(loc=loc, scale_tril=scale), logits)
+        if self.covariance_type == 'full':
+            logits, loc, diag, tril = phi
+            scale = torch.diag_embed(diag.exp() + 1e-5)
+            mask = torch.tril(torch.ones_like(scale, dtype=bool), diagonal=-1)
+            scale = torch.masked_scatter(scale, mask, tril)
+            # expanded automatically for tied covariances
+            return Mixture(MultivariateNormal(loc=loc, scale_tril=scale), logits)
+
+        elif self.covariance_type == 'lowrank':
+            logits, loc, diag, lowrank = phi
+            diag = diag.exp() + 1e-5
+            # reshape the flat factor to (..., features, cov_rank) for cov_factor
+            lowrank = lowrank.reshape(*lowrank.shape[:-1], -1, self.cov_rank)
+            # expanded automatically for tied covariances
+            return Mixture(
+                LowRankMultivariateNormal(loc=loc, cov_factor=lowrank, cov_diag=diag), logits
+            )
+
+        elif self.covariance_type in ['diag', 'spherical']:
+            logits, loc, diag = phi
+            diag = diag.exp() + 1e-5
+            # expanded automatically for spherical and tied covariances
+            return Mixture(Independent(Normal(loc, diag), 1), logits)
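
Two closing notes. First, a toy check of the `masked_scatter` trick used by the 'full' branch, with example values of my own: the diagonal is made positive via `exp`, and the free entries fill the strictly lower triangle in row-major order.

import torch

diag = torch.zeros(1, 3)  # one component, 3 features
tril = torch.tensor([[1.0, 2.0, 3.0]])  # 3 * 2 // 2 = 3 free entries

scale = torch.diag_embed(diag.exp() + 1e-5)
mask = torch.tril(torch.ones_like(scale, dtype=bool), diagonal=-1)
scale = torch.masked_scatter(scale, mask, tril)
# rows: [1.0, 0, 0], [1.0, 1.0, 0], [2.0, 3.0, 1.0] (diagonal is exp(0) + 1e-5)

Second, end-to-end usage if this merges as-is. A sketch, assuming `gmm(c)` returns the conditional distribution per `LazyDistribution`:

import torch
import zuko

gmm = zuko.flows.GMM(features=5, context=3, components=4,
                     covariance_type='lowrank', cov_rank=2)

c = torch.randn(3)  # context vector
dist = gmm(c)  # Mixture over LowRankMultivariateNormal components

x = dist.sample((16,))  # (16, 5)
log_p = dist.log_prob(x)  # (16,)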