FEA: Parallel Partial Emulation (aka Multitask GPs with a shared kernel) #2470

Open
wants to merge 4 commits into main
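Usage sketch (adapted from the unit test added in this PR; the model name and toy data below are illustrative): every task shares a single data kernel with a single set of hyperparameters, and the inter-task covariance is fixed to the identity.

```python
import torch
import gpytorch
from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.kernels import ParallelPartialKernel, RBFKernel
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.means import ConstantMean, MultitaskMean


class ParallelPartialGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        num_tasks = train_y.shape[-1]
        # One shared data kernel for all tasks; the inter-task covariance is I_T
        self.mean_module = MultitaskMean(ConstantMean(), num_tasks=num_tasks)
        self.covar_module = ParallelPartialKernel(RBFKernel(), num_tasks=num_tasks)

    def forward(self, x):
        return MultitaskMultivariateNormal(self.mean_module(x), self.covar_module(x))


# Illustrative toy data: 100 scalar inputs, 4 output tasks
train_x = torch.linspace(0, 1, 100)
train_y = torch.stack(
    [torch.sin(2 * train_x), torch.cos(2 * train_x), torch.sin(4 * train_x), torch.cos(4 * train_x)],
    dim=-1,
)

likelihood = MultitaskGaussianLikelihood(num_tasks=4)
model = ParallelPartialGPModel(train_x, train_y, likelihood)
# Train with gpytorch.mlls.ExactMarginalLogLikelihood as usual (see the unit test below)
```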
285 changes: 285 additions & 0 deletions examples/03_Multitask_Exact_GPs/Parallel_Partial_GP_Regression.ipynb

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions examples/03_Multitask_Exact_GPs/index.rst
@@ -14,13 +14,17 @@ Multi-output (vector valued functions)
- If the outputs share the same kernel and mean, you can train a `Batch Independent Multioutput GP`_.
- Otherwise, you can train a `ModelList Multioutput GP`_.

- **Partially correlated output dimensions**: for cases with a massive number of outputs.
See the `Parallel Partial GP Regression`_ example, which implements the inference strategy proposed in `Gu et al., 2016`_.

.. toctree::
:maxdepth: 1
:hidden:

Multitask_GP_Regression.ipynb
Batch_Independent_Multioutput_GP.ipynb
ModelList_GP_Regression.ipynb
Parallel_Partial_GP_Regression.ipynb

Scalar function with multiple tasks
----------------------------------------
@@ -41,11 +45,17 @@ This setting should be used only when each input corresponds to a single task.
.. _Bonilla et al., 2008:
https://papers.nips.cc/paper/3189-multi-task-gaussian-process-prediction

.. _Gu et al., 2016:
https://projecteuclid.org/journals/annals-of-applied-statistics/volume-10/issue-3/Parallel-partial-Gaussian-process-emulation-for-computer-models-with-massive/10.1214/16-AOAS934.pdf

.. _Batch Independent Multioutput GP:
./Batch_Independent_Multioutput_GP.ipynb

.. _ModelList Multioutput GP:
./ModelList_GP_Regression.ipynb

.. _Parallel Partial GP Regression:
./Parallel_Partial_GP_Regression.ipynb

.. _Hadamard Multitask GP Regression:
./Hadamard_Multitask_GP_Regression.ipynb
2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
@@ -18,6 +18,7 @@
from .multi_device_kernel import MultiDeviceKernel
from .multitask_kernel import MultitaskKernel
from .newton_girard_additive_kernel import NewtonGirardAdditiveKernel
from .parallel_partial_kernel import ParallelPartialKernel
from .periodic_kernel import PeriodicKernel
from .piecewise_polynomial_kernel import PiecewisePolynomialKernel
from .polynomial_kernel import PolynomialKernel
@@ -53,6 +54,7 @@
"MaternKernel",
"MultitaskKernel",
"NewtonGirardAdditiveKernel",
"ParallelPartialKernel",
"PeriodicKernel",
"PiecewisePolynomialKernel",
"PolynomialKernel",
62 changes: 62 additions & 0 deletions gpytorch/kernels/parallel_partial_kernel.py
@@ -0,0 +1,62 @@
#!/usr/bin/env python3

from linear_operator import to_linear_operator
from linear_operator.operators import BlockInterleavedLinearOperator

from .kernel import Kernel


class ParallelPartialKernel(Kernel):
r"""
A simplified analogue of :class:`gpytorch.kernels.MultitaskKernel` in which tasks are assumed
to be independent and a single, common kernel is shared by all tasks.

Given a base covariance module to be used for the data, :math:`K_{XX}`,
this kernel returns :math:`K = I_T \otimes K_{XX}`, where :math:`T` is the
number of tasks.

.. note::

In this construction it is crucial that all coordinates (or tasks)
share the same kernel, with the same kernel parameters. This
simplification of the inter-task kernel yields computational savings
when the number of tasks is large: if each task had its own kernel
(as in the batch-independent Gaussian process construction), each
task would have a different design correlation matrix, requiring a
separate `n x n` inversion per coordinate, where `n` is the number
of data points. Furthermore, when training the Gaussian process
surrogate, only a single set of kernel parameters has to be
estimated, rather than one set per coordinate.

:param ~gpytorch.kernels.Kernel covar_module: Kernel to use as the data kernel.
:param int num_tasks: Number of tasks.
:param dict kwargs: Additional arguments to pass to the base :class:`~gpytorch.kernels.Kernel` constructor.

Example:
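    (A minimal usage sketch; the RBF data kernel, task count, and data are illustrative.)

    >>> data_covar_module = gpytorch.kernels.RBFKernel()
    >>> covar_module = gpytorch.kernels.ParallelPartialKernel(data_covar_module, num_tasks=4)
    >>> x = torch.randn(10, 5)  # 10 data points with 5 input dimensions
    >>> covar = covar_module(x)  # (4 * 10) x (4 * 10) block-interleaved covariance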
"""

def __init__(
self,
covar_module: Kernel,
num_tasks: int,
**kwargs,
):
super(ParallelPartialKernel, self).__init__(**kwargs)
self.covar_module = covar_module
self.num_tasks = num_tasks

def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
if last_dim_is_batch:
raise RuntimeError("ParallelPartialKernel does not accept the last_dim_is_batch argument.")
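        # Evaluate the shared data kernel once, then replicate it as `num_tasks`
        # identical, interleaved diagonal blocks (one block per task).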
covar_x = to_linear_operator(self.covar_module.forward(x1, x2, **params))
res = BlockInterleavedLinearOperator(covar_x.repeat(self.num_tasks, 1, 1))
return res.diagonal(dim1=-1, dim2=-2) if diag else res

def num_outputs_per_input(self, x1, x2):
"""
Given `n` data points `x1` and `m` data points `x2`, this parallel
partial kernel returns an `(n*num_tasks) x (m*num_tasks)`
block-diagonal covariance matrix with `num_tasks` blocks of shape
`n x m` on the diagonal.
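
For example (illustrative numbers), with `num_tasks = 4`, 10 points in
`x1`, and 7 points in `x2`, the evaluated kernel matrix has shape `40 x 28`.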
"""
return self.num_tasks
92 changes: 92 additions & 0 deletions test/examples/test_parallel_partial_gp_regression.py
@@ -0,0 +1,92 @@
#!/usr/bin/env python3

import os
import random
import unittest
from math import pi

import torch

import gpytorch
from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.kernels import ParallelPartialKernel, RBFKernel
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.means import ConstantMean, MultitaskMean


# Four sinusoidal test functions with additive Gaussian noise (standard deviation 0.1)
def eval_functions(train_x, noisy=True):
train_y1 = torch.sin(train_x * (2 * pi)) + (torch.randn(train_x.size()) * 0.1 if noisy else 0)
train_y2 = torch.cos(train_x * (2 * pi)) + (torch.randn(train_x.size()) * 0.1 if noisy else 0)
train_y3 = torch.cos(train_x * pi) + (torch.randn(train_x.size()) * 0.1 if noisy else 0)
train_y4 = torch.cos(train_x * pi) + (torch.randn(train_x.size()) * 0.1 if noisy else 0)
return torch.stack([train_y1, train_y2, train_y3, train_y4], -1)


class ParallelPartialGPModel(gpytorch.models.ExactGP):
def __init__(self, train_x, train_y, likelihood):
super(ParallelPartialGPModel, self).__init__(train_x, train_y, likelihood)
self.mean_module = MultitaskMean(ConstantMean(), num_tasks=train_y.shape[1])
self.covar_module = ParallelPartialKernel(RBFKernel(), num_tasks=train_y.shape[1])

def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return MultitaskMultivariateNormal(mean_x, covar_x)


class TestParallelPartialGPRegression(unittest.TestCase):
def setUp(self):
if os.getenv("UNLOCK_SEED") is None or os.getenv("UNLOCK_SEED").lower() == "false":
self.rng_state = torch.get_rng_state()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
random.seed(0)

def tearDown(self):
if hasattr(self, "rng_state"):
torch.set_rng_state(self.rng_state)

def test_parallel_partial_gp_mean_abs_error(self):

# Training inputs and noisy training outputs
train_x = torch.linspace(0, 1, 100)
train_y = eval_functions(train_x)

# Likelihood and model
likelihood = MultitaskGaussianLikelihood(num_tasks=train_y.shape[1])
model = ParallelPartialGPModel(train_x, train_y, likelihood)

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes MultitaskGaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

# Training
n_iter = 50
for _ in range(n_iter):
optimizer.zero_grad()
output = model(train_x)
loss = -mll(output, train_y)
loss.backward()
optimizer.step()

# Test the model
model.eval()
likelihood.eval()
test_x = torch.linspace(0, 1, 51)
test_y = eval_functions(test_x, noisy=False)
test_preds = likelihood(model(test_x)).mean
for task in range(train_y.shape[1]):
mean_abs_error_task = torch.mean(torch.abs(test_y[:, task] - test_preds[:, task]))
self.assertLess(mean_abs_error_task.item(), 0.05)


if __name__ == "__main__":
unittest.main()