From 1c315ba6f6c0e6610e414a653b003c9d40508984 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Wed, 5 Jul 2023 11:05:03 -0700 Subject: [PATCH] `einsum` for `LCEAKernel` (#1918) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1918 This commit fixes the `einsum` indexing in the forward pass of the `LCEAKernel`. Reviewed By: sdaulton Differential Revision: D47226145 fbshipit-source-id: f92c519759cdb43513b0e472701e31223be8242a --- botorch/models/kernels/contextual_lcea.py | 47 +++++++++++++++-------- test/models/kernels/test_contextual.py | 29 +++++++++++--- 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/botorch/models/kernels/contextual_lcea.py b/botorch/models/kernels/contextual_lcea.py index b8c960ac26..3f3a0cb312 100644 --- a/botorch/models/kernels/contextual_lcea.py +++ b/botorch/models/kernels/contextual_lcea.py @@ -385,19 +385,38 @@ def forward( covariance matrices together """ # context covar matrix - if not self.training: - context_covar = self._context_covar - else: - context_covar = self._eval_context_covar() + context_covar = ( + self._eval_context_covar() if self.training else self._context_covar + ) + base_covar_perm = self._eval_base_covar_perm(x1, x2) + # expand context_covar to match base_covar_perm + if base_covar_perm.dim() > context_covar.dim(): + context_covar = context_covar.expand(base_covar_perm.shape) + # then weight by the context kernel + # compute the base kernel on the d parameters + einsum_str = "...nnki, ...nnki -> ...n" if diag else "...ki, ...ki -> ..." + covar_dense = torch.einsum(einsum_str, context_covar, base_covar_perm) + if diag: + return DiagLinearOperator(covar_dense) + return DenseLinearOperator(covar_dense) + + def _eval_base_covar_perm(self, x1: Tensor, x2: Tensor) -> Tensor: + """Computes the base covariance matrix on x1, x2, applying permutations and + reshaping the kernel matrix as required by `forward`. + + NOTE: Using the notation n = num_observations, k = num_contexts, d = input_dim, + the input tensors have to have the following shapes. + + Args: + x1: `batch_shape x n x (k*d)`-dim Tensor of kernel inputs. + x2: `batch_shape x n x (k*d)`-dim Tensor of kernel inputs. + + Returns: + `batch_shape x n x n x k x k`-dim Tensor of base covariance values. + """ if self.permutation is not None: x1 = x1[..., self.permutation] x2 = x2[..., self.permutation] - # check input batch size if b x ns x n x d: expand context_covar to - # b x ns x num_context x num_context - if x1.dim() > context_covar.dim(): - context_covar = context_covar.expand( - x1.shape[:-1] + torch.Size([x2.shape[-2]]) + context_covar.shape - ) # turn last two dimensions of n x (k*d) into (n*k) x d. x1_exp = x1.reshape(*x1.shape[:-2], -1, self.num_param) x2_exp = x2.reshape(*x2.shape[:-2], -1, self.num_param) @@ -417,10 +436,4 @@ def forward( .view(view_shape) .permute(*list(range(x1.ndim - 2)), -4, -2, -3, -1) ) - # then weight by the context kernel - # compute the base kernel on the d parameters - einsum_str = "...kk, ...nnkk -> ...n" if diag else "...kk, ...kk -> ..." - covar_dense = torch.einsum(einsum_str, context_covar, base_covar_perm) - if diag: - return DiagLinearOperator(covar_dense) - return DenseLinearOperator(covar_dense) + return base_covar_perm diff --git a/test/models/kernels/test_contextual.py b/test/models/kernels/test_contextual.py index 8f8a710df7..7257e4b092 100644 --- a/test/models/kernels/test_contextual.py +++ b/test/models/kernels/test_contextual.py @@ -83,17 +83,34 @@ def testLCEAKernel(self): self.assertEqual(kernel.outputscale_list.shape, torch.Size([num_contexts])) # test diag works well for lazy tensor - x1 = torch.rand(5, 4) - x2 = torch.rand(5, 4) + num_obs, num_contexts, input_dim = 5, 2, 2 + x1 = torch.rand(num_obs, num_contexts * input_dim) + x2 = torch.rand(num_obs, num_contexts * input_dim) res = kernel(x1, x2).to_dense() res_diag = kernel(x1, x2, diag=True) - self.assertLess(torch.norm(res_diag - res.diag()), 1e-4) + self.assertAllClose(res_diag, res.diag(), atol=1e-4) # test batch evaluation - x1 = torch.rand(3, 5, 4) - x2 = torch.rand(3, 5, 4) + batch_dim = 3 + x1 = torch.rand(batch_dim, num_obs, num_contexts * input_dim) + x2 = torch.rand(batch_dim, num_obs, num_contexts * input_dim) res = kernel(x1, x2).to_dense() - self.assertEqual(res.shape, torch.Size([3, 5, 5])) + self.assertEqual(res.shape, torch.Size([batch_dim, num_obs, num_obs])) + + # testing efficient `einsum` with naive `sum` implementation + context_covar = kernel._eval_context_covar() + if x1.dim() > context_covar.dim(): + context_covar = context_covar.expand( + x1.shape[:-1] + torch.Size([x2.shape[-2]]) + context_covar.shape + ) + base_covar_perm = kernel._eval_base_covar_perm(x1, x2) + expected_res = (context_covar * base_covar_perm).sum(dim=-2).sum(dim=-1) + self.assertAllClose(expected_res, res) + + # diagonal batch evaluation + res_diag = kernel(x1, x2, diag=True).to_dense() + expected_res_diag = torch.diagonal(expected_res, dim1=-1, dim2=-2) + self.assertAllClose(expected_res_diag, res_diag) # test input context_weight, # test input embs_dim_list (one categorical feature)