Skip to content

Commit

Permalink
add combine_terms option to exact MLL
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelstanton committed Dec 16, 2021
1 parent 5a0ff6b commit ba187db
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
10 changes: 7 additions & 3 deletions gpytorch/distributions/multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def lazy_covariance_matrix(self):
else:
return lazify(super().covariance_matrix)

def log_prob(self, value):
def log_prob(self, value, combine_terms=True):
if settings.fast_computations.log_prob.off():
return super().log_prob(value)

Expand All @@ -167,9 +167,13 @@ def log_prob(self, value):
# Get log determininant and first part of quadratic form
covar = covar.evaluate_kernel()
inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
norm_const = diff.size(-1) * math.log(2 * math.pi)
split_terms = [inv_quad, logdet, norm_const]

res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
return res
if combine_terms:
return -0.5 * sum(split_terms)
else:
return [-0.5 * term for term in split_terms]

def rsample(self, sample_shape=torch.Size(), base_samples=None):
covar = self.lazy_covariance_matrix
Expand Down
16 changes: 13 additions & 3 deletions gpytorch/mlls/exact_marginal_log_likelihood.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

import torch

from ..distributions import MultivariateNormal
from ..likelihoods import _GaussianLikelihoodBase
from .marginal_log_likelihood import MarginalLogLikelihood
Expand Down Expand Up @@ -59,9 +61,17 @@ def forward(self, function_dist, target, *params):

# Get the log prob of the marginal distribution
output = self.likelihood(function_dist, *params)
res = output.log_prob(target)
res = self._add_other_terms(res, params)
res = output.log_prob(target, combine_terms=self.combine_terms)

# Scale by the amount of data we have
num_data = function_dist.event_shape.numel()
return res.div_(num_data)

if self.combine_terms:
res = self._add_other_terms(res, params)
return res.div(num_data)
else:
norm_const = res[-1]
other_terms = torch.zeros_like(norm_const)
other_terms = self._add_other_terms(other_terms, params)
res.append(other_terms)
return [term.div(num_data) for term in res]
3 changes: 2 additions & 1 deletion gpytorch/mlls/marginal_log_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class MarginalLogLikelihood(Module):
these functions must be negated for optimization).
"""

def __init__(self, likelihood, model):
def __init__(self, likelihood, model, combine_terms=True):
super(MarginalLogLikelihood, self).__init__()
if not isinstance(model, GP):
raise RuntimeError(
Expand All @@ -35,6 +35,7 @@ def __init__(self, likelihood, model):
)
self.likelihood = likelihood
self.model = model
self.combine_terms = combine_terms

def forward(self, output, target, **kwargs):
r"""
Expand Down

0 comments on commit ba187db

Please sign in to comment.