[Bug] Fantasy Models for Multitask GPs are broken #2577

Open
ancorso opened this issue Sep 1, 2024 · 4 comments


ancorso commented Sep 1, 2024

🐛 Bug

Calling get_fantasy_model on a simple multitask GP raises a RuntimeError.

To reproduce

Here is a minimal working example that reproduces the bug:

import torch
import gpytorch

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, n_tasks):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=n_tasks
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.RBFKernel(), num_tasks=n_tasks, rank=1
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


input_dim = 1
output_dim = 2
n_train = 10
train_x = torch.randn(n_train, input_dim)
train_y = torch.randn(n_train, output_dim)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=output_dim)
model = MultitaskGPModel(train_x, train_y, likelihood, output_dim)

model.train()
model.eval()

# get a posterior to fill in caches
model(torch.randn(n_train, input_dim))

# Generate some new data and get fantasy model
n_new = 5
new_x = torch.randn(n_new, input_dim)
new_y = torch.randn(n_new, output_dim)

model.get_fantasy_model(new_x, new_y)

Stack trace/error message

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/acorso/micromamba/envs/newenv/lib/python3.12/site-packages/gpytorch/models/exact_gp.py", line 239, in get_fantasy_model
    new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/acorso/micromamba/envs/newenv/lib/python3.12/site-packages/gpytorch/models/exact_prediction_strategies.py", line 196, in get_fantasy_strategy
    small_system_rhs = targets - fant_mean - ftcm
                       ~~~~~~~~~~~~~~~~~~~~^~~~~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (10) at non-singleton dimension 1

Expected Behavior

The fantasy model, with appropriately updated caches, should be returned.
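One way to pin down "appropriately updated" for a test: a fantasy model built from old data plus new data should give the same posterior mean as an exact GP conditioned on all of the data at once. The sketch below checks that identity in plain PyTorch and does not call GPyTorch. It assumes a zero prior mean, fixed hyperparameters, an RBF data kernel, a rank-1-plus-diagonal task covariance, and an interleaved layout (point index major, task index minor). The helper names (`rbf`, `covar`, `noisy_covar`) are hypothetical, and `alpha` is only an analogue of the `mean_cache` seen in the traceback, not GPyTorch's code.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64
n, m, q, t, d = 10, 5, 4, 2, 1  # old points, new points, test points, tasks, input dim
noise = 0.1

def rbf(a, b, lengthscale=1.0):
    # Squared-exponential kernel between row sets a (p, d) and b (r, d)
    return torch.exp(-0.5 * torch.cdist(a, b).pow(2) / lengthscale**2)

# Task covariance: rank-1 factor plus a diagonal, similar in spirit to MultitaskKernel(rank=1)
W = torch.randn(t, 1, dtype=dtype)
B = W @ W.T + 0.5 * torch.eye(t, dtype=dtype)

def covar(a, b):
    # Interleaved layout (assumed): row index = point * t + task
    return torch.kron(rbf(a, b), B)

def noisy_covar(a):
    # Same covariance with Gaussian noise added on every (point, task) entry
    return covar(a, a) + noise * torch.eye(a.shape[0] * t, dtype=dtype)

X, X_new, X_test = (torch.randn(k, d, dtype=dtype) for k in (n, m, q))
y, y_new = torch.randn(n, t, dtype=dtype), torch.randn(m, t, dtype=dtype)

# 1) Condition on old and new data at once (zero prior mean)
X_all = torch.cat([X, X_new])
y_all = torch.cat([y, y_new]).reshape(-1)
mean_full = covar(X_test, X_all) @ torch.linalg.solve(noisy_covar(X_all), y_all)

# 2) Condition on old data only, then update with the new points
K = noisy_covar(X)
alpha = torch.linalg.solve(K, y.reshape(-1))  # cached solve, analogous to mean_cache
K_new_old = covar(X_new, X)
residual = y_new.reshape(-1) - K_new_old @ alpha  # new targets minus posterior mean at X_new
S = noisy_covar(X_new) - K_new_old @ torch.linalg.solve(K, K_new_old.T)
C_test_new = covar(X_test, X_new) - covar(X_test, X) @ torch.linalg.solve(K, K_new_old.T)
mean_fantasy = covar(X_test, X) @ alpha + C_test_new @ torch.linalg.solve(S, residual)

print(torch.allclose(mean_full, mean_fantasy))
```

A regression test for this issue could make the same comparison between get_fantasy_model and a model rebuilt from scratch on the concatenated data, using several new points rather than one.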

System information

  • gpytorch 1.12
  • torch 2.4.0+cu121
  • macOS

Additional context

This has already been discussed in #800 and #805, and PR #2317, which supposedly implemented this feature, was merged. However, the test added in that PR passes only because it adds a single new data point to build the fantasy model. If you set n_new=1 in the example above, it also runs without error, but since it fails for more than one new point, I'm skeptical that the right thing is happening even in the single-point case.
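Why n_new=1 runs while n_new=5 fails can be seen from broadcasting alone. The sketch below takes as given the shapes described in the follow-up analysis in this thread (targets and fant_mean are (m, d_out), while ftcm is flat with length m * d_out) and uses zero tensors with those shapes; the helper `rhs_shape` is hypothetical and only mirrors the variable names from the traceback.

```python
import torch

def rhs_shape(m, d_out):
    # Shapes as described in this thread (assumed): targets and fant_mean are (m, d_out),
    # while the einsum result ftcm is a flat vector of length m * d_out
    targets = torch.zeros(m, d_out)
    fant_mean = torch.zeros(m, d_out)
    ftcm = torch.zeros(m * d_out)
    return (targets - fant_mean - ftcm).shape

# m = 1: (1, 2) broadcasts against (2,), so no error is raised
print(rhs_shape(1, 2))

# m = 5: (5, 2) cannot broadcast against (10,), the same mismatch as in the traceback
try:
    rhs_shape(5, 2)
except RuntimeError as err:
    print(err)
```

With a single new point, the flat vector has exactly d_out entries, one per task for that point, so the silent broadcast is at least shape-consistent; with more than one point the flat vector has to be reshaped in whatever order GPyTorch uses internally, which a single-point test never exercises.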

ancorso added the bug label Sep 1, 2024
ancorso (Author) commented Sep 1, 2024

With some investigation, it seems that this line
ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y", [fant_train_covar, self.mean_cache])
produces a 1D tensor of length m*d_out, where m is the number of new (fantasy) points and d_out is the output dimension of the GP. In the next line
small_system_rhs = targets - fant_mean - ftcm
targets and fant_mean are both tensors of shape m x d_out. So the fix looks like a simple reshape, but I'm not sure which shape is the correct one to use. If someone can weigh in on what is correct here, I'm happy to submit a PR with a new test that covers this case. @gpleiss?
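To make the reshape question concrete: ftcm appears to be the prior cross-covariance between the new and training points applied to the cached solve, i.e. the posterior-mean correction at the new points, flattened. The sketch below assumes an interleaved layout for multitask data (point index major, task index minor, which is what reshape(-1) on an (m, d_out) tensor produces) and a multitask covariance written as a Kronecker product of a data kernel and a task covariance. All tensors are random stand-ins; `k_new_train` and `task_covar` are hypothetical names, and whether GPyTorch's fant_train_covar really uses this ordering is exactly what a fix has to confirm. Under the assumption, reshaping ftcm to (m, d_out) and flattening targets and fant_mean give the same residual, and both match an element-by-element computation.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64
n, m, d_out = 10, 5, 2

# Stand-in pieces (hypothetical values): data kernel between new and training points,
# a symmetric task covariance, and a flat cached vector of length n * d_out
k_new_train = torch.rand(m, n, dtype=dtype)
task_covar = torch.tensor([[1.0, 0.3], [0.3, 0.8]], dtype=dtype)
mean_cache = torch.randn(n * d_out, dtype=dtype)
targets = torch.randn(m, d_out, dtype=dtype)
fant_mean = torch.randn(m, d_out, dtype=dtype)

# Interleaved layout assumed: row index = point * d_out + task
fant_train_covar = torch.kron(k_new_train, task_covar)  # (m * d_out, n * d_out)
ftcm = torch.einsum("yz,z->y", fant_train_covar, mean_cache)  # flat, length m * d_out

# Option 1: reshape ftcm to match targets; option 2: flatten targets and fant_mean
rhs_2d = targets - fant_mean - ftcm.reshape(m, d_out)
rhs_flat = targets.reshape(-1) - fant_mean.reshape(-1) - ftcm

# Element-wise reference: correction for new point i, task j
cache_2d = mean_cache.reshape(n, d_out)
reference = torch.empty(m, d_out, dtype=dtype)
for i in range(m):
    for j in range(d_out):
        correction = sum(
            k_new_train[i, p] * task_covar[j, s] * cache_2d[p, s]
            for p in range(n)
            for s in range(d_out)
        )
        reference[i, j] = targets[i, j] - fant_mean[i, j] - correction

print(torch.allclose(rhs_2d, reference), torch.allclose(rhs_flat, rhs_2d.reshape(-1)))
```

If the flat vector were instead task-major (all points for task 0, then all points for task 1), the matching reshape would be ftcm.reshape(d_out, m).T rather than ftcm.reshape(m, d_out), so the two layouts only agree when m = 1.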

gpleiss (Member) commented Sep 3, 2024

@ancorso we are in the middle of reworking the prediction strategies code (timeline TBD) for a 2.0 release. However, we'd accept a bugfix PR in the meantime (as long as it's not too much work on your end!)

Balandat (Collaborator) commented Sep 5, 2024

cc @hvarfner

@williamjsdavis

@gpleiss We have put together a bugfix PR, #2587, which passes the tests we ran locally. If you have some time, let us know what you
