diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index abd868ba7..976f0a5ea 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -193,8 +193,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ prefix = string.ascii_lowercase[: max(fant_train_covar.dim() - self.mean_cache.dim() - 1, 0)] ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y", [fant_train_covar, self.mean_cache]) - targets_ = torch.reshape(targets, ftcm.shape) - fant_mean_ = torch.reshape(fant_mean, ftcm.shape) + targets_ = torch.reshape(targets, (-1, ftcm.shape[-1])) + fant_mean_ = torch.reshape(fant_mean, (-1, ftcm.shape[-1])) small_system_rhs = targets_ - fant_mean_ - ftcm small_system_rhs = small_system_rhs.unsqueeze(-1) # Schur complement of a spd matrix is guaranteed to be positive definite