
Commit 328ebd0
fix failing unittests and increase coverage
samuelstanton committed Dec 23, 2021
1 parent d8bd497 commit 328ebd0
Showing 7 changed files with 118 additions and 23 deletions.
4 changes: 2 additions & 2 deletions gpytorch/distributions/multitask_multivariate_normal.py
@@ -203,12 +203,12 @@ def get_base_samples(self, sample_shape=torch.Size()):
             return base_samples.view(new_shape).transpose(-1, -2).contiguous()
         return base_samples.view(*sample_shape, *self._output_shape)
 
-    def log_prob(self, value):
+    def log_prob(self, value, combine_terms=True):
         if not self._interleaved:
             # flip shape of last two dimensions
             new_shape = value.shape[:-2] + value.shape[:-3:-1]
             value = value.view(new_shape).transpose(-1, -2).contiguous()
-        return super().log_prob(value.view(*value.shape[:-2], -1))
+        return super().log_prob(value.view(*value.shape[:-2], -1), combine_terms)
 
     @property
     def mean(self):
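
Since `MultitaskMultivariateNormal.log_prob` only flattens the task dimension and defers to the base class, the new flag passes straight through. A minimal sketch of the resulting call pattern (hypothetical shapes, not part of this commit):

    import torch
    from gpytorch.distributions import MultitaskMultivariateNormal

    mean = torch.randn(4, 2)          # 4 data points, 2 tasks
    covar = torch.eye(4 * 2)          # dense (n*t) x (n*t) covariance
    mtmvn = MultitaskMultivariateNormal(mean, covar)
    value = torch.randn(4, 2)

    # same three terms as the base MultivariateNormal class returns
    terms = mtmvn.log_prob(value, combine_terms=False)
    assert torch.isclose(sum(terms), mtmvn.log_prob(value))
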
4 changes: 3 additions & 1 deletion gpytorch/distributions/multivariate_normal.py
@@ -167,7 +167,9 @@ def log_prob(self, value, combine_terms=True):
         # Get log determinant and first part of quadratic form
         covar = covar.evaluate_kernel()
         inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
-        norm_const = diff.size(-1) * math.log(2 * math.pi)
+        norm_const = torch.tensor(
+            diff.size(-1) * math.log(2 * math.pi)
+        ).to(inv_quad)
         split_terms = [inv_quad, logdet, norm_const]
 
         if combine_terms:
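
Judging by the updated tests below, the three entries of `split_terms` (quadratic form, log determinant, and the n·log 2π normalization constant) each come back with the -1/2 factor already applied, so their plain sum equals the combined log probability. A short sketch of the new call pattern, mirroring the tests:

    import torch
    from gpytorch.distributions import MultivariateNormal
    from gpytorch.lazy import DiagLazyTensor

    mean = torch.randn(4)
    var = torch.randn(4).abs()
    values = torch.randn(4)
    mvn = MultivariateNormal(mean, DiagLazyTensor(var))

    # three pre-scaled terms; names here are for illustration only
    inv_quad, logdet, norm_const = mvn.log_prob(values, combine_terms=False)
    assert torch.isclose(inv_quad + logdet + norm_const, mvn.log_prob(values))
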
5 changes: 3 additions & 2 deletions gpytorch/mlls/exact_marginal_log_likelihood.py
@@ -19,6 +19,7 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):
     :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
     :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param ~bool combine_terms (optional): If `False`, the MLL call returns each MLL term separately
 
     Example:
         >>> # model is a gpytorch.models.ExactGP
@@ -30,10 +31,10 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):
         >>> loss.backward()
     """
 
-    def __init__(self, likelihood, model):
+    def __init__(self, likelihood, model, combine_terms=True):
         if not isinstance(likelihood, _GaussianLikelihoodBase):
             raise RuntimeError("Likelihood must be Gaussian for exact inference")
-        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model)
+        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model, combine_terms)
 
     def _add_other_terms(self, res, params):
         # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
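
With the flag threaded through the constructor, a training loop can inspect or reweight the terms individually. A hedged sketch, assuming the model/likelihood setup from the docstring example (the term names are assumptions; the new tests below only check that four terms come back):

    # model is a gpytorch.models.ExactGP, likelihood is its GaussianLikelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model, combine_terms=False)
    output = model(train_x)
    inv_quad, logdet, norm_const, other = mll(output, train_y)  # four separate terms
    loss = -(inv_quad + logdet + norm_const + other)            # same scalar as combine_terms=True
    loss.backward()
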
22 changes: 11 additions & 11 deletions gpytorch/mlls/leave_one_out_pseudo_likelihood.py
@@ -29,6 +29,7 @@ class LeaveOneOutPseudoLikelihood(ExactMarginalLogLikelihood):
     :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
     :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param ~bool combine_terms (optional): If `False`, the MLL call returns each MLL term separately
 
     Example:
         >>> # model is a gpytorch.models.ExactGP
@@ -40,11 +41,6 @@ class LeaveOneOutPseudoLikelihood(ExactMarginalLogLikelihood):
         >>> loss.backward()
     """
 
-    def __init__(self, likelihood, model):
-        super().__init__(likelihood=likelihood, model=model)
-        self.likelihood = likelihood
-        self.model = model
-
     def forward(self, function_dist: MultivariateNormal, target: Tensor, *params) -> Tensor:
         r"""
         Computes the leave one out likelihood given :math:`p(\mathbf f)` and `\mathbf y`
@@ -60,12 +56,16 @@ def forward(self, function_dist: MultivariateNormal, target: Tensor, *params) -> Tensor:
         identity = torch.eye(*L.shape[-2:], dtype=m.dtype, device=m.device)
         sigma2 = 1.0 / L._cholesky_solve(identity, upper=False).diagonal(dim1=-1, dim2=-2)  # 1 / diag(inv(K))
         mu = target - L._cholesky_solve((target - m).unsqueeze(-1), upper=False).squeeze(-1) * sigma2
-        term1 = -0.5 * sigma2.log()
-        term2 = -0.5 * (target - mu).pow(2.0) / sigma2
-        res = (term1 + term2).sum(dim=-1)
 
-        res = self._add_other_terms(res, params)
-
         # Scale by the amount of data we have and then add on the scaled constant
         num_data = target.size(-1)
-        return res.div_(num_data) - 0.5 * math.log(2 * math.pi)
+        term1 = sigma2.log().sum(-1)
+        term2 = ((target - mu).pow(2.0) / sigma2).sum(-1)
+        norm_const = torch.tensor(num_data * math.log(2 * math.pi)).to(term1)
+        other_term = self._add_other_terms(torch.zeros_like(term1), params)
+        split_terms = [term1, term2, norm_const, other_term]
+
+        if self.combine_terms:
+            return -0.5 / num_data * sum(split_terms)
+        else:
+            return [-0.5 / num_data * term for term in split_terms]
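
For reference, the rewritten `forward` returns the leave-one-out predictive log probability averaged over the data; the split terms correspond directly to this decomposition (with n = `num_data` and σ_i², μ_i as computed above):

    \mathrm{LOO} = -\frac{1}{2n} \Big(
        \underbrace{\sum_{i=1}^{n} \log \sigma_i^2}_{\texttt{term1}}
      + \underbrace{\sum_{i=1}^{n} \frac{(y_i - \mu_i)^2}{\sigma_i^2}}_{\texttt{term2}}
      + \underbrace{n \log 2\pi}_{\texttt{norm\_const}}
      + \texttt{other\_term} \Big)
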
16 changes: 14 additions & 2 deletions test/distributions/test_multivariate_normal.py
@@ -219,20 +219,32 @@ def test_log_prob(self, cuda=False):
         var = torch.randn(4, device=device, dtype=dtype).abs_()
         values = torch.randn(4, device=device, dtype=dtype)
 
-        res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
+        mvn = MultivariateNormal(mean, DiagLazyTensor(var))
+        res = mvn.log_prob(values)
         actual = TMultivariateNormal(mean, torch.eye(4, device=device, dtype=dtype) * var).log_prob(values)
         self.assertLess((res - actual).div(res).abs().item(), 1e-2)
 
+        res2 = mvn.log_prob(values, combine_terms=False)
+        assert len(res2) == 3
+        res2 = sum(res2)
+        self.assertLess((res2 - actual).div(res).abs().item(), 1e-2)
+
         mean = torch.randn(3, 4, device=device, dtype=dtype)
         var = torch.randn(3, 4, device=device, dtype=dtype).abs_()
         values = torch.randn(3, 4, device=device, dtype=dtype)
 
-        res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
+        mvn = MultivariateNormal(mean, DiagLazyTensor(var))
+        res = mvn.log_prob(values)
         actual = TMultivariateNormal(
             mean, var.unsqueeze(-1) * torch.eye(4, device=device, dtype=dtype).repeat(3, 1, 1)
         ).log_prob(values)
         self.assertLess((res - actual).div(res).abs().norm(), 1e-2)
 
+        res2 = mvn.log_prob(values, combine_terms=False)
+        assert len(res2) == 3
+        res2 = sum(res2)
+        self.assertLess((res2 - actual).div(res).abs().norm(), 1e-2)
+
     def test_log_prob_cuda(self):
         if torch.cuda.is_available():
             with least_used_cuda_device():
59 changes: 59 additions & 0 deletions test/mlls/test_exact_marginal_log_likelihood.py
@@ -0,0 +1,59 @@
+import unittest
+
+import torch
+
+import gpytorch
+
+from .test_leave_one_out_pseudo_likelihood import ExactGPModel
+
+
+class TestExactMarginalLogLikelihood(unittest.TestCase):
+    def get_data(self, shapes, combine_terms, dtype=None, device=None):
+        train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True)
+        train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
+        likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device)
+        model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device)
+        exact_mll = gpytorch.mlls.ExactMarginalLogLikelihood(
+            likelihood=likelihood,
+            model=model,
+            combine_terms=combine_terms
+        )
+        return train_x, train_y, exact_mll
+
+    def test_smoke(self):
+        """Make sure the exact_mll works without batching."""
+        train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=True)
+        output = exact_mll.model(train_x)
+        loss = -exact_mll(output, train_y)
+        loss.backward()
+        self.assertTrue(train_x.grad is not None)
+
+        train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=False)
+        output = exact_mll.model(train_x)
+        mll_out = exact_mll(output, train_y)
+        loss = -1 * sum(mll_out)
+        loss.backward()
+        assert len(mll_out) == 4
+        self.assertTrue(train_x.grad is not None)
+
+    def test_smoke_batch(self):
+        """Make sure the exact_mll works with batching."""
+        train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=True)
+        output = exact_mll.model(train_x)
+        loss = -exact_mll(output, train_y)
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
+        train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=False)
+        output = exact_mll.model(train_x)
+        mll_out = exact_mll(output, train_y)
+        loss = -1 * sum(mll_out)
+        assert len(mll_out) == 4
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
+
+if __name__ == "__main__":
+    unittest.main()
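
Since the new file keeps the standard `unittest` entry point, it can be run on its own, e.g. (assuming the test directories are importable packages, as the relative import of `ExactGPModel` requires):

    python -m unittest test.mlls.test_exact_marginal_log_likelihood
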
31 changes: 26 additions & 5 deletions test/mlls/test_leave_one_out_pseudo_likelihood.py
@@ -21,36 +21,57 @@ def forward(self, x):
 
 
 class TestLeaveOneOutPseudoLikelihood(unittest.TestCase):
-    def get_data(self, shapes, dtype=None, device=None):
+    def get_data(self, shapes, combine_terms, dtype=None, device=None):
         train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True)
         train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
         likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device)
         model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device)
-        loocv = gpytorch.mlls.LeaveOneOutPseudoLikelihood(likelihood=likelihood, model=model)
+        loocv = gpytorch.mlls.LeaveOneOutPseudoLikelihood(
+            likelihood=likelihood,
+            model=model,
+            combine_terms=combine_terms
+        )
         return train_x, train_y, loocv
 
     def test_smoke(self):
         """Make sure the loocv works without batching."""
-        train_x, train_y, loocv = self.get_data([5, 2])
+        train_x, train_y, loocv = self.get_data([5, 2], combine_terms=True)
         output = loocv.model(train_x)
         loss = -loocv(output, train_y)
         loss.backward()
         self.assertTrue(train_x.grad is not None)
 
+        train_x, train_y, loocv = self.get_data([5, 2], combine_terms=False)
+        output = loocv.model(train_x)
+        mll_out = loocv(output, train_y)
+        loss = -1 * sum(mll_out)
+        loss.backward()
+        assert len(mll_out) == 4
+        self.assertTrue(train_x.grad is not None)
+
     def test_smoke_batch(self):
         """Make sure the loocv works with batching."""
-        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2])
+        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2], combine_terms=True)
         output = loocv.model(train_x)
         loss = -loocv(output, train_y)
         assert loss.shape == (3, 3, 3)
         loss.sum().backward()
         self.assertTrue(train_x.grad is not None)
 
+        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2], combine_terms=False)
+        output = loocv.model(train_x)
+        mll_out = loocv(output, train_y)
+        loss = -1 * sum(mll_out)
+        assert len(mll_out) == 4
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
     def test_check_bordered_system(self):
         """Make sure that the bordered system solves match the naive solution."""
         n = 5
         # Compute the pseudo-likelihood via the bordered systems in O(n^3)
-        train_x, train_y, loocv = self.get_data([n, 2], dtype=torch.float64)
+        train_x, train_y, loocv = self.get_data([n, 2], combine_terms=True, dtype=torch.float64)
         output = loocv.model(train_x)
         loocv_1 = loocv(output, train_y)
