From 8924d1b66bfe7de9d31dccbcef4974476682a268 Mon Sep 17 00:00:00 2001 From: Sam Lishak Date: Tue, 1 Oct 2024 11:32:20 -0700 Subject: [PATCH] More efficient sampling from KroneckerMultiTaskGP (#2460) Summary: ## Motivation See https://github.com/pytorch/botorch/issues/2310#issuecomment-2260743319 ```python import torch from botorch.models import KroneckerMultiTaskGP n_inputs = 10 n_tasks = 4 n_train = 2048 n_test = 1 device = torch.device("cuda:0") train_x = torch.randn(n_train, n_inputs, dtype=torch.float64, device=device) train_y = torch.randn(n_train, n_tasks, dtype=torch.float64, device=device) test_x = torch.randn(n_test, n_inputs, dtype=torch.float64, device=device) gp = KroneckerMultiTaskGP(train_x, train_y) posterior = gp.posterior(test_x) posterior.rsample(torch.Size([256, 1])) ``` The final line requires allocation of 128GB of GPU memory, because of the call to `torch.cholesky_solve` with B shaped `(256, 1, 8192, 1)` and L shaped `(8192, 8192)`. By moving the largest batch dimension to the final position, we should achieve a more efficient operation. Also fix docstring for `MultitaskGPPosterior`. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: https://github.com/pytorch/botorch/pull/2460 Test Plan: Passes unit tests (specifically `test_multitask.py`). Benchmarking results: ![image](https://github.com/user-attachments/assets/1eca54be-1ed4-43c9-bb50-a18cf24d00f5) ![image](https://github.com/user-attachments/assets/016322f6-992a-45bf-b175-e76208c11b12) ## Related PRs N/A Reviewed By: saitcakmak Differential Revision: D63678866 Pulled By: Balandat fbshipit-source-id: 6675c66dadd62934f95fabafe7b3f0155a1c0c6f --- botorch/posteriors/multitask.py | 44 ++++++++++++++++++++++++++++--- test/posteriors/test_multitask.py | 24 +++++++++++++++-- 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py index 03a6267dbc..76a2df43d4 100644 --- a/botorch/posteriors/multitask.py +++ b/botorch/posteriors/multitask.py @@ -36,9 +36,11 @@ def __init__( distribution: Posterior multivariate normal distribution. joint_covariance_matrix: Joint test train covariance matrix over the entire tensor. - train_train_covar: Covariance matrix of train points in the data space. - test_obs_covar: Covariance matrix of test x train points in the data space. + test_train_covar: Covariance matrix of test x train points in the data + space. train_diff: Difference between train mean and train responses. + test_mean: Test mean response. + train_train_covar: Covariance matrix of train points in the data space. train_noise: Training noise covariance. test_noise: Only used if posterior should contain observation noise. Testing noise covariance. @@ -226,7 +228,9 @@ def rsample_from_base_samples( train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples ) train_covar_plus_noise = self.train_train_covar + self.train_noise - obs_solve = train_covar_plus_noise.solve(obs_minus_samples.unsqueeze(-1)) + obs_solve = _permute_solve( + train_covar_plus_noise, obs_minus_samples.unsqueeze(-1) + ) # and multiply the test-observed matrix against the result of the solve updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1) @@ -286,3 +290,37 @@ def _draw_from_base_covar( res = covar_root.matmul(base_samples) return res.squeeze(-1) + + +def _permute_solve(A: LinearOperator, b: Tensor) -> LinearOperator: + r"""Solve the batched linear system AX = b, where b is a batched column + vector. The solve is carried out after permuting the largest batch + dimension of b to the final position, which results in a more efficient + matrix-matrix solve. + + This ideally should be handled upstream (in GPyTorch, linear_operator or + PyTorch), after which any uses of this method can be replaced with + `A.solve(b)`. + + Args: + A: LinearOperator of shape (n, n) + b: Tensor of shape (..., n, 1) + + Returns: + LinearOperator of shape (..., n, 1) + """ + # permute dimensions to move largest batch dimension to the end (more efficient + # than unsqueezing) + perm = list(range(b.ndim)) + if b.ndim > 2: + largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1]) + perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1] + b_p = b.permute(*perm) + + x_p = A.solve(b_p) + + # Undo permutation + inverse_perm = torch.argsort(torch.tensor(perm)) + x = x_p.permute(*inverse_perm) + + return x diff --git a/test/posteriors/test_multitask.py b/test/posteriors/test_multitask.py index 42913d3ef7..1b6c6b7dbc 100644 --- a/test/posteriors/test_multitask.py +++ b/test/posteriors/test_multitask.py @@ -8,9 +8,10 @@ import torch from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.models.multitask import KroneckerMultiTaskGP -from botorch.posteriors.multitask import MultitaskGPPosterior +from botorch.posteriors.multitask import _permute_solve, MultitaskGPPosterior from botorch.sampling.normal import IIDNormalSampler from botorch.utils.testing import BotorchTestCase +from linear_operator.operators import to_linear_operator def get_posterior_test_cases( @@ -41,7 +42,6 @@ def get_posterior_test_cases( class TestMultitaskGPPosterior(BotorchTestCase): - def _test_MultitaskGPPosterior(self, dtype: torch.dtype) -> None: post_list = get_posterior_test_cases(device=self.device, dtype=dtype) sample_shaping = torch.Size([5, 3]) @@ -189,3 +189,23 @@ def test_draw_from_base_covar(self): base_samples = torch.randn(4, 10, 1, device=self.device) with self.assertRaises(RuntimeError): res = posterior._draw_from_base_covar(sym_mat, base_samples) + + +class TestPermuteSolve(BotorchTestCase): + def test_permute_solve_tensor(self): + # Random PSD matrix + a = torch.randn(32, 32, device=self.device, dtype=torch.float64) + A = torch.mm(a, a.t()) + + # Random batched column vector + b = torch.randn(4, 1, 32, 1, device=self.device, dtype=torch.float64) + + # Compare results of permuted and standard solve + x_1 = _permute_solve(to_linear_operator(A), b) + x_2 = torch.linalg.solve(A, b) + self.assertAllClose(x_1, x_2) + + # Ensure also works if b is not batched + x_1 = _permute_solve(to_linear_operator(A), b[0, 0, :, :]) + x_2 = torch.linalg.solve(A, b[0, 0, :, :]) + self.assertAllClose(x_1, x_2)