Skip to content

Commit

Permalink
Handle Batched Computation
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard2926 committed Sep 12, 2024
1 parent 02ac2c9 commit f5216c3
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gpytorch/models/exact_prediction_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
prefix = string.ascii_lowercase[: max(fant_train_covar.dim() - self.mean_cache.dim() - 1, 0)]
ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y", [fant_train_covar, self.mean_cache])

targets_ = torch.reshape(targets, ftcm.shape)
fant_mean_ = torch.reshape(fant_mean, ftcm.shape)
targets_ = torch.reshape(targets, (-1, ftcm.shape[-1]))
fant_mean_ = torch.reshape(fant_mean, (-1, ftcm.shape[-1]))
small_system_rhs = targets_ - fant_mean_ - ftcm
small_system_rhs = small_system_rhs.unsqueeze(-1)
# Schur complement of a spd matrix is guaranteed to be positive definite
Expand Down

0 comments on commit f5216c3

Please sign in to comment.