
add combine_terms option to exact MLL #1863

Open · wants to merge 9 commits into main
Conversation

samuelstanton (Contributor) commented Dec 16, 2021:

I've found logging the inv_quad and logdet terms separately (rather than just the train loss) to be very helpful for debugging. Right now, classes like VariationalELBO have a combine_terms option that allows the user to sum the terms themselves after the MLL call. This is a nice feature, since otherwise you essentially have to pay for an extra training step just to log the terms separately.

In this PR I've demonstrated how we could go about adding this option to the subclasses of MarginalLogLikelihood, starting with the Gaussian likelihood case. A few unit tests aren't passing yet, but I wanted to check whether this feature would be approved before fixing them up.
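A minimal, self-contained sketch of the pattern the PR proposes (toy function and argument names are hypothetical, not GPyTorch's implementation): the MLL call either returns the summed loss, as today, or the individual terms so they can be logged without paying for an extra forward pass.

```python
def exact_mll_terms(inv_quad, logdet, norm_const, other_terms=(), combine_terms=True):
    """Toy stand-in for an exact-MLL forward pass (hypothetical names).

    The Gaussian marginal log likelihood decomposes as
        log p(y) = -0.5 * (inv_quad + logdet + norm_const) + other terms
    so each piece can be reported on its own.
    """
    split = [-0.5 * inv_quad, -0.5 * logdet, -0.5 * norm_const, *other_terms]
    if combine_terms:
        return sum(split)  # single scalar loss, the current behavior
    return split           # separate terms, e.g. for logging


# Usage: log the pieces and still recover the combined loss in one pass.
terms = exact_mll_terms(2.0, 1.0, 3.0, combine_terms=False)
loss = sum(terms)  # identical to combine_terms=True
```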

samuelstanton (Contributor, Author) commented:

@jacobrgardner @gpleiss any thoughts?

gpleiss (Member) commented Dec 21, 2021:

Yeah, this would be awesome to add!

samuelstanton (Contributor, Author) commented:

@gpleiss how does everything look?

@@ -203,12 +203,12 @@ def get_base_samples(self, sample_shape=torch.Size()):
return base_samples.view(new_shape).transpose(-1, -2).contiguous()
return base_samples.view(*sample_shape, *self._output_shape)

-    def log_prob(self, value):
+    def log_prob(self, value, combine_terms=True):
Collaborator commented:

In general, I don't think we want to add flags to the standard log_prob call here, so that we maintain compatibility with the MVN API in PyTorch. Let's make this a _log_prob method, with log_prob just calling _log_prob(value=value, combine_terms=True)?
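A sketch of the wrapper pattern this review comment suggests (the class and its _terms helper are hypothetical stand-ins, not GPyTorch's actual implementation): the public log_prob keeps the torch.distributions signature, while a private _log_prob carries the extra flag.

```python
class MVNLogProbSketch:
    """Hypothetical sketch: keep log_prob() API-compatible with
    torch.distributions.MultivariateNormal, and put combine_terms
    on a separate private method."""

    def _terms(self, value):
        # Stand-in for the real inv_quad / logdet computation so the
        # sketch is self-contained.
        inv_quad, logdet, norm_const = value, 1.0, 2.0
        return [-0.5 * inv_quad, -0.5 * logdet, -0.5 * norm_const]

    def _log_prob(self, value, combine_terms=True):
        split = self._terms(value)
        return sum(split) if combine_terms else split

    def log_prob(self, value):
        # Public signature stays identical to the PyTorch MVN API.
        return self._log_prob(value, combine_terms=True)
```

Callers wanting the separate terms would reach for the private method (or, per the compromise floated later in this thread, a public method with a different name), while existing log_prob callers are unaffected.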

@@ -142,7 +146,7 @@ def lazy_covariance_matrix(self):
else:
return lazify(super().covariance_matrix)

-    def log_prob(self, value):
+    def log_prob(self, value, combine_terms=True):
Collaborator commented:
Same here, change to _log_prob?

wjmaddox (Collaborator) commented Feb 3, 2022:

Looks like the failing unit test was flaky.

@@ -59,9 +62,17 @@ def forward(self, function_dist, target, *params):

# Get the log prob of the marginal distribution
output = self.likelihood(function_dist, *params)
-        res = output.log_prob(target)
         res = self._add_other_terms(res, params)
+        res = output.log_prob(target, combine_terms=self.combine_terms)
samuelstanton (Contributor, Author) commented:
Stylistically, the proposed change from log_prob to _log_prob is problematic here, because you would essentially be calling a "private" method publicly. More generally I think the combine_terms option is broadly useful and burying it inside the class makes it harder to use.

Personally I don't see why the GPyTorch log_prob API can't allow optional keyword arguments like combine_terms, as long as the default behavior is consistent.

samuelstanton (Contributor, Author) commented:
@gpleiss @jacobrgardner care to weigh in?

samuelstanton (Contributor, Author) commented:
A compromise would be to just call it log_prob_terms instead of _log_prob.

Comment on lines +176 to +177:

    split_terms = [inv_quad, logdet, norm_const]
    split_terms = [-0.5 * term for term in split_terms]

Collaborator commented:

Suggested change:

-    split_terms = [inv_quad, logdet, norm_const]
-    split_terms = [-0.5 * term for term in split_terms]
+    split_terms = [-0.5 * inv_quad, logdet, -0.5 * norm_const]
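Whichever scaling the code settles on, the recombined terms must reproduce the closed-form Gaussian log density. A pure-math 1-D check (not GPyTorch code; note that in this textbook identity all three terms carry the -0.5 factor, so the logdet scaling in the suggestion above is worth double-checking against the surrounding code):

```python
import math

# Textbook identity the split terms should reproduce (1-D case):
#   log N(x; 0, var) = -0.5 * (inv_quad + logdet + n * log(2*pi))
# with inv_quad = x**2 / var, logdet = log(var), n = 1.
def log_prob_from_terms(x, var):
    inv_quad = x ** 2 / var
    logdet = math.log(var)
    norm_const = math.log(2 * math.pi)  # n = 1
    split = [-0.5 * inv_quad, -0.5 * logdet, -0.5 * norm_const]
    return split, sum(split)

def log_prob_closed_form(x, var):
    return -0.5 * math.log(2 * math.pi * var) - x ** 2 / (2 * var)

split, total = log_prob_from_terms(1.3, 2.0)
assert abs(total - log_prob_closed_form(1.3, 2.0)) < 1e-12
```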

@@ -17,6 +19,7 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):

:param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
:param ~gpytorch.models.ExactGP model: The exact GP model
+    :param ~bool combine_terms (optional): If `False`, the MLL call returns each MLL term separately
Collaborator commented:
Should probably also describe what happens if there are "other terms" (i.e., that they are added to the returned elements).

actual = TMultivariateNormal(mean, torch.eye(4, device=device, dtype=dtype) * var).log_prob(values)
self.assertLess((res - actual).div(res).abs().item(), 1e-2)

+        res2 = mvn.log_prob_terms(values)
+        assert len(res2) == 3
Collaborator commented:
Suggested change:

-        assert len(res2) == 3
+        self.assertEqual(len(res2), 3)

Collaborator commented:
Also apply this in the other places in the tests below.
