Skip to content

Commit

Permalink
einsum for LCEAKernel (pytorch#1918)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1918

This commit fixes the `einsum` indexing in the forward pass of the `LCEAKernel`.

Reviewed By: sdaulton

Differential Revision: D47226145

fbshipit-source-id: f92c519759cdb43513b0e472701e31223be8242a
Loading branch information
SebastianAment authored and facebook-github-bot committed Jul 5, 2023
1 parent 7eb847a commit 1c315ba
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 23 deletions.
47 changes: 30 additions & 17 deletions botorch/models/kernels/contextual_lcea.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,19 +385,38 @@ def forward(
covariance matrices together
"""
# context covar matrix
if not self.training:
context_covar = self._context_covar
else:
context_covar = self._eval_context_covar()
context_covar = (
self._eval_context_covar() if self.training else self._context_covar
)
base_covar_perm = self._eval_base_covar_perm(x1, x2)
# expand context_covar to match base_covar_perm
if base_covar_perm.dim() > context_covar.dim():
context_covar = context_covar.expand(base_covar_perm.shape)
# then weight by the context kernel
# compute the base kernel on the d parameters
einsum_str = "...nnki, ...nnki -> ...n" if diag else "...ki, ...ki -> ..."
covar_dense = torch.einsum(einsum_str, context_covar, base_covar_perm)
if diag:
return DiagLinearOperator(covar_dense)
return DenseLinearOperator(covar_dense)

def _eval_base_covar_perm(self, x1: Tensor, x2: Tensor) -> Tensor:
"""Computes the base covariance matrix on x1, x2, applying permutations and
reshaping the kernel matrix as required by `forward`.
NOTE: Using the notation n = num_observations, k = num_contexts, d = input_dim,
the input tensors have to have the following shapes.
Args:
x1: `batch_shape x n x (k*d)`-dim Tensor of kernel inputs.
x2: `batch_shape x n x (k*d)`-dim Tensor of kernel inputs.
Returns:
`batch_shape x n x n x k x k`-dim Tensor of base covariance values.
"""
if self.permutation is not None:
x1 = x1[..., self.permutation]
x2 = x2[..., self.permutation]
# check input batch size if b x ns x n x d: expand context_covar to
# b x ns x num_context x num_context
if x1.dim() > context_covar.dim():
context_covar = context_covar.expand(
x1.shape[:-1] + torch.Size([x2.shape[-2]]) + context_covar.shape
)
# turn last two dimensions of n x (k*d) into (n*k) x d.
x1_exp = x1.reshape(*x1.shape[:-2], -1, self.num_param)
x2_exp = x2.reshape(*x2.shape[:-2], -1, self.num_param)
Expand All @@ -417,10 +436,4 @@ def forward(
.view(view_shape)
.permute(*list(range(x1.ndim - 2)), -4, -2, -3, -1)
)
# then weight by the context kernel
# compute the base kernel on the d parameters
einsum_str = "...kk, ...nnkk -> ...n" if diag else "...kk, ...kk -> ..."
covar_dense = torch.einsum(einsum_str, context_covar, base_covar_perm)
if diag:
return DiagLinearOperator(covar_dense)
return DenseLinearOperator(covar_dense)
return base_covar_perm
29 changes: 23 additions & 6 deletions test/models/kernels/test_contextual.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,34 @@ def testLCEAKernel(self):
self.assertEqual(kernel.outputscale_list.shape, torch.Size([num_contexts]))

# test diag works well for lazy tensor
x1 = torch.rand(5, 4)
x2 = torch.rand(5, 4)
num_obs, num_contexts, input_dim = 5, 2, 2
x1 = torch.rand(num_obs, num_contexts * input_dim)
x2 = torch.rand(num_obs, num_contexts * input_dim)
res = kernel(x1, x2).to_dense()
res_diag = kernel(x1, x2, diag=True)
self.assertLess(torch.norm(res_diag - res.diag()), 1e-4)
self.assertAllClose(res_diag, res.diag(), atol=1e-4)

# test batch evaluation
x1 = torch.rand(3, 5, 4)
x2 = torch.rand(3, 5, 4)
batch_dim = 3
x1 = torch.rand(batch_dim, num_obs, num_contexts * input_dim)
x2 = torch.rand(batch_dim, num_obs, num_contexts * input_dim)
res = kernel(x1, x2).to_dense()
self.assertEqual(res.shape, torch.Size([3, 5, 5]))
self.assertEqual(res.shape, torch.Size([batch_dim, num_obs, num_obs]))

# testing efficient `einsum` with naive `sum` implementation
context_covar = kernel._eval_context_covar()
if x1.dim() > context_covar.dim():
context_covar = context_covar.expand(
x1.shape[:-1] + torch.Size([x2.shape[-2]]) + context_covar.shape
)
base_covar_perm = kernel._eval_base_covar_perm(x1, x2)
expected_res = (context_covar * base_covar_perm).sum(dim=-2).sum(dim=-1)
self.assertAllClose(expected_res, res)

# diagonal batch evaluation
res_diag = kernel(x1, x2, diag=True).to_dense()
expected_res_diag = torch.diagonal(expected_res, dim1=-1, dim2=-2)
self.assertAllClose(expected_res_diag, res_diag)

# test input context_weight,
# test input embs_dim_list (one categorical feature)
Expand Down

0 comments on commit 1c315ba

Please sign in to comment.