From ba187db978106a0f1c7ef3c5426dfba2d1f3efc5 Mon Sep 17 00:00:00 2001 From: samuelstanton Date: Thu, 16 Dec 2021 12:57:49 -0500 Subject: [PATCH] add combine_terms option to exact MLL --- gpytorch/distributions/multivariate_normal.py | 10 +++++++--- gpytorch/mlls/exact_marginal_log_likelihood.py | 16 +++++++++++++--- gpytorch/mlls/marginal_log_likelihood.py | 3 ++- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py index 06b8ad6f5..e80ed78b6 100644 --- a/gpytorch/distributions/multivariate_normal.py +++ b/gpytorch/distributions/multivariate_normal.py @@ -142,7 +142,7 @@ def lazy_covariance_matrix(self): else: return lazify(super().covariance_matrix) - def log_prob(self, value): + def log_prob(self, value, combine_terms=True): if settings.fast_computations.log_prob.off(): return super().log_prob(value) @@ -167,9 +167,13 @@ def log_prob(self, value): # Get log determininant and first part of quadratic form covar = covar.evaluate_kernel() inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True) + norm_const = diff.size(-1) * math.log(2 * math.pi) + split_terms = [inv_quad, logdet, norm_const] - res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)]) - return res + if combine_terms: + return -0.5 * sum(split_terms) + else: + return [-0.5 * term for term in split_terms] def rsample(self, sample_shape=torch.Size(), base_samples=None): covar = self.lazy_covariance_matrix diff --git a/gpytorch/mlls/exact_marginal_log_likelihood.py b/gpytorch/mlls/exact_marginal_log_likelihood.py index f33408868..b05fbf323 100644 --- a/gpytorch/mlls/exact_marginal_log_likelihood.py +++ b/gpytorch/mlls/exact_marginal_log_likelihood.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +import torch + from ..distributions import MultivariateNormal from ..likelihoods import _GaussianLikelihoodBase from .marginal_log_likelihood import MarginalLogLikelihood @@ -59,9 +61,17 @@ def forward(self, function_dist, target, *params): # Get the log prob of the marginal distribution output = self.likelihood(function_dist, *params) - res = output.log_prob(target) - res = self._add_other_terms(res, params) + res = output.log_prob(target, combine_terms=self.combine_terms) # Scale by the amount of data we have num_data = function_dist.event_shape.numel() - return res.div_(num_data) + + if self.combine_terms: + res = self._add_other_terms(res, params) + return res.div(num_data) + else: + norm_const = res[-1] + other_terms = torch.zeros_like(norm_const) + other_terms = self._add_other_terms(other_terms, params) + res.append(other_terms) + return [term.div(num_data) for term in res] diff --git a/gpytorch/mlls/marginal_log_likelihood.py b/gpytorch/mlls/marginal_log_likelihood.py index be696c9c8..541b2000f 100644 --- a/gpytorch/mlls/marginal_log_likelihood.py +++ b/gpytorch/mlls/marginal_log_likelihood.py @@ -25,7 +25,7 @@ class MarginalLogLikelihood(Module): these functions must be negated for optimization). """ - def __init__(self, likelihood, model): + def __init__(self, likelihood, model, combine_terms=True): super(MarginalLogLikelihood, self).__init__() if not isinstance(model, GP): raise RuntimeError( @@ -35,6 +35,7 @@ def __init__(self, likelihood, model): ) self.likelihood = likelihood self.model = model + self.combine_terms = combine_terms def forward(self, output, target, **kwargs): r"""