log_erfcx and standard_normal_(log_)hazard for stable pairwise GPs (#1919)

Summary:
Pull Request resolved: #1919

This commit introduces `log_erfcx` and `standard_normal_(log_)hazard`, which improve the stability of the computations involved in pairwise Gaussian process models and might also be of more general interest and use.

Reviewed By: Balandat

Differential Revision: D47232960

fbshipit-source-id: 4fbcdc759acb31c6b5d5197d6b1f28b4b192996a
SebastianAment authored and facebook-github-bot committed Jul 5, 2023
1 parent 7eb847a commit 78c0f40
Showing 5 changed files with 134 additions and 68 deletions.
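As a rough sketch of the stability issue the summary refers to (illustrative only, not part of the commit's diff): the previous pairwise likelihood computed the hazard as exp(log pdf - log cdf) with torch.distributions.Normal, which blows up once the cdf underflows, whereas the new helper stays finite.

    import torch
    from botorch.utils.probability.utils import standard_normal_log_hazard

    # An extreme utility difference: Phi(z) underflows to 0 in float64 here.
    z = torch.tensor([-40.0], dtype=torch.float64)

    # Old approach, mirroring the removed _calc_z_derived body: exp(log pdf - log cdf).
    std_norm = torch.distributions.Normal(
        torch.zeros(1, dtype=torch.float64), torch.ones(1, dtype=torch.float64)
    )
    naive_hazard = torch.exp(std_norm.log_prob(z) - std_norm.cdf(z).log())  # inf

    # New approach via log_erfcx: finite, roughly 40.0 here.
    stable_hazard = standard_normal_log_hazard(-z).exp()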
25 changes: 14 additions & 11 deletions botorch/models/likelihoods/pairwise.py
@@ -15,6 +15,11 @@
from typing import Any, Tuple

import torch
from botorch.utils.probability.utils import (
log_ndtr,
log_phi,
standard_normal_log_hazard,
)
from gpytorch.likelihoods import Likelihood
from torch import Tensor
from torch.distributions import Bernoulli
@@ -120,16 +125,15 @@ def _calc_z(self, utility: Tensor, D: Tensor) -> Tensor:

def _calc_z_derived(self, z: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""Calculate auxiliary statistics derived from z, including log pdf,
log cdf, and the hazard function (pdf divided by cdf)"""
std_norm = torch.distributions.normal.Normal(
torch.zeros(1, dtype=z.dtype, device=z.device),
torch.ones(1, dtype=z.dtype, device=z.device),
)
z_logpdf = std_norm.log_prob(z)
z_cdf = std_norm.cdf(z)
z_logcdf = torch.log(z_cdf)
hazard = torch.exp(z_logpdf - z_logcdf)
return z_logpdf, z_logcdf, hazard
log cdf, and the hazard function (pdf divided by cdf)
Args:
z: A Tensor of arbitrary shape.
Returns:
Tensors with logpdf(z), logcdf(z), and hazard function values evaluated at -z.
"""
return log_phi(z), log_ndtr(z), standard_normal_log_hazard(-z).exp()

def p(self, utility: Tensor, D: Tensor, log: bool = False) -> Tensor:
z = self._calc_z(utility=utility, D=D)
@@ -148,7 +152,6 @@ def negative_log_gradient_sum(self, utility: Tensor, D: Tensor) -> Tensor:
_, _, h = self._calc_z_derived(z)
h_factor = h / math.sqrt(2)
grad = (h_factor.unsqueeze(-2) @ (-D)).squeeze(-2)

return grad

def negative_log_hessian_sum(self, utility: Tensor, D: Tensor) -> Tensor:
7 changes: 4 additions & 3 deletions botorch/models/pairwise_gp.py
@@ -411,13 +411,14 @@ def _grad_posterior_f(
utility = torch.tensor(utility, dtype=self.datapoints.dtype)
prior_mean = prior_mean.cpu()

# NOTE: During the optimization, it can occur that b, p, and g_ are NaNs, though
# in the cases that occurred during testing, the optimization routine escaped and
# terminated successfully without NaNs in the result.
b = self.likelihood.negative_log_gradient_sum(utility=utility, D=D)

# g_ = covar_inv x (utility - pred_prior)
p = (utility - prior_mean).unsqueeze(-1).to(covar_chol)
g_ = torch.cholesky_solve(p, covar_chol).squeeze(-1)
g = g_ + b

if ret_np:
return g.cpu().numpy()
else:
@@ -575,7 +576,7 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
self._x0 = x.copy() # save for warm-starting
f = torch.tensor(x, dtype=datapoints.dtype, device=datapoints.device)

# To perform hyperparameter optimization, this need to be recalculated
# To perform hyperparameter optimization, this needs to be recalculated
# when calling forward() in order to obtain correct gradients
# self.likelihood_hess is updated here is for the rare case where we
# do not want to call forward()
43 changes: 41 additions & 2 deletions botorch/utils/probability/utils.py
@@ -23,10 +23,12 @@
_log_2 = math.log(2)
_sqrt_pi = math.sqrt(pi)
_inv_sqrt_pi = 1 / _sqrt_pi
_inv_sqrt_2pi = (2 * pi) ** -0.5
_neg_inv_sqrt_2 = -(2**-0.5)
_inv_sqrt_2pi = 1 / math.sqrt(2 * pi)
_inv_sqrt_2 = 1 / math.sqrt(2)
_neg_inv_sqrt_2 = -_inv_sqrt_2
_log_sqrt_2pi = math.log(2 * pi) / 2
STANDARDIZED_RANGE: Tuple[float, float] = (-1e6, 1e6)
_log_two_inv_sqrt_2pi = _log_2 - _log_sqrt_2pi # = log(2 / sqrt(2 * pi))


def case_dispatcher(
@@ -190,6 +192,43 @@ def log_erfc(x: Tensor) -> Tensor:
)


def log_erfcx(x: Tensor) -> Tensor:
"""Computes the logarithm of the complementary scaled error function in a
numerically stable manner. The GitHub issue tracks progress toward moving this
feature into PyTorch in C++: https://github.com/pytorch/pytorch/issues/31945.
Args:
x: An input tensor with dtype torch.float32 or torch.float64.
Returns:
A tensor of values of the same type and shape as x containing log(erfcx(x)).
"""
is_pos = x > 0
x_pos = x.masked_fill(~is_pos, 0)
x_neg = x.masked_fill(is_pos, 0)
return torch.where(
is_pos,
torch.special.erfcx(x_pos).log(),
torch.special.erfc(x_neg).log() + x.square(),
)
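# Illustrative sketch (not from the diff): for large positive x, erfc(x) underflows,
# so torch.special.erfc(x).log() is -inf, while log_erfcx(x) stays finite because
# erfcx(x) = exp(x**2) * erfc(x) only decays like 1 / (x * sqrt(pi)). For instance,
# with x = torch.tensor([30.0], dtype=torch.float64), torch.special.erfc(x).log()
# is -inf, whereas log_erfcx(x) is roughly -3.97.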


def standard_normal_log_hazard(x: Tensor) -> Tensor:
"""Computes the logarithm of the hazard function of the standard normal
distribution, i.e. `log(phi(x) / Phi(-x))`.
Args:
x: A tensor of any shape, with either float32 or float64 dtypes.
Returns:
A Tensor of the same shape as `x`, containing the values of the logarithm of the
hazard function evaluated at `x`.
"""
# NOTE: using _inv_sqrt_2 instead of _neg_inv_sqrt_2 means we are computing Phi(-x).
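# Derivation: Phi(-x) = erfc(x / sqrt(2)) / 2 and erfcx(y) = exp(y**2) * erfc(y), so
# phi(x) / Phi(-x) = 2 / (sqrt(2 * pi) * erfcx(x / sqrt(2))); taking logs gives
# log(2 / sqrt(2 * pi)) - log_erfcx(x / sqrt(2)), i.e. a - log_erfcx(b * x) below.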
a, b = get_constants_like((_log_two_inv_sqrt_2pi, _inv_sqrt_2), x)
return a - log_erfcx(b * x)


def log_prob_normal_in(a: Tensor, b: Tensor) -> Tensor:
r"""Computes the probability that a standard normal random variable takes a value
in \[a, b\], i.e. log(Phi(b) - Phi(a)), where Phi is the standard normal CDF.
24 changes: 16 additions & 8 deletions test/models/test_pairwise_gp.py
@@ -6,15 +6,18 @@

import itertools
import warnings
from typing import Dict, Tuple, Union

import torch
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.exceptions import OptimizationWarning, UnsupportedError
from botorch.fit import fit_gpytorch_mll
from botorch.models.likelihoods.pairwise import (
PairwiseLikelihood,
PairwiseLogitLikelihood,
PairwiseProbitLikelihood,
)
from botorch.models.model import Model
from botorch.models.pairwise_gp import (
_ensure_psd_with_jitter,
PairwiseGP,
@@ -29,17 +32,22 @@
from gpytorch.means import ConstantMean
from gpytorch.priors import GammaPrior, SmoothedBoxPrior
from linear_operator.utils.errors import NotPSDError
from torch import Tensor


class TestPairwiseGP(BotorchTestCase):
def _make_rand_mini_data(self, batch_shape, X_dim=2, **tkwargs):
def _make_rand_mini_data(
self, batch_shape, X_dim=2, **tkwargs
) -> Tuple[Tensor, Tensor, Tensor]:
train_X = torch.rand(*batch_shape, 2, X_dim, **tkwargs)
train_Y = train_X.sum(dim=-1, keepdim=True)
train_comp = torch.topk(train_Y, k=2, dim=-2).indices.transpose(-1, -2)

return train_X, train_Y, train_comp

def _get_model_and_data(self, batch_shape, X_dim=2, likelihood_cls=None, **tkwargs):
def _get_model_and_data(
self, batch_shape, X_dim=2, likelihood_cls=None, **tkwargs
) -> Tuple[Model, Dict[str, Union[Tensor, PairwiseLikelihood]]]:
train_X, train_Y, train_comp = self._make_rand_mini_data(
batch_shape=batch_shape, X_dim=X_dim, **tkwargs
)
@@ -52,7 +60,7 @@ def _get_model_and_data(self, batch_shape, X_dim=2, likelihood_cls=None, **tkwar
model = PairwiseGP(**model_kwargs)
return model, model_kwargs

def test_pairwise_gp(self):
def test_pairwise_gp(self) -> None:
for batch_shape, dtype, likelihood_cls in itertools.product(
(torch.Size(), torch.Size([2])),
(torch.float, torch.double),
@@ -180,7 +188,7 @@ def test_pairwise_gp(self):
with self.assertRaises(RuntimeError):
model.set_train_data(train_X, changed_train_comp, strict=True)

def test_consolidation(self):
def test_consolidation(self) -> None:
for batch_shape, dtype, likelihood_cls in itertools.product(
(torch.Size(), torch.Size([2])),
(torch.float, torch.double),
@@ -240,7 +248,7 @@ def test_consolidation(self):
# Pass the original comparisons through mll should work
mll(pred, dup_comp)

def test_condition_on_observations(self):
def test_condition_on_observations(self) -> None:
for batch_shape, dtype, likelihood_cls in itertools.product(
(torch.Size(), torch.Size([2])),
(torch.float, torch.double),
@@ -345,7 +353,7 @@ def test_condition_on_observations(self):
)
)

def test_fantasize(self):
def test_fantasize(self) -> None:
for batch_shape, dtype, likelihood_cls in itertools.product(
(torch.Size(), torch.Size([2])),
(torch.float, torch.double),
@@ -371,7 +379,7 @@ def test_fantasize(self):
fm = model.fantasize(X=X_f, sampler=sampler, observation_noise=False)
self.assertIsInstance(fm, model.__class__)

def test_load_state_dict(self):
def test_load_state_dict(self) -> None:
model, _ = self._get_model_and_data(batch_shape=[])
sd = model.state_dict()
with self.assertRaises(UnsupportedError):
@@ -386,7 +394,7 @@ def test_load_state_dict(self):
for buffer_name in model._buffer_names:
self.assertIsNone(model.get_buffer(buffer_name))

def test_helper_functions(self):
def test_helper_functions(self) -> None:
for batch_shape, dtype in itertools.product(
(torch.Size(), torch.Size([2])), (torch.float, torch.double)
):
103 changes: 59 additions & 44 deletions test/utils/probability/test_utils.py
@@ -10,10 +10,12 @@
from botorch.utils.probability import ndtr, utils
from botorch.utils.probability.utils import (
log_erfc,
log_erfcx,
log_ndtr,
log_phi,
log_prob_normal_in,
phi,
standard_normal_log_hazard,
)
from botorch.utils.testing import BotorchTestCase
from numpy.polynomial.legendre import leggauss as numpy_leggauss
@@ -166,50 +168,53 @@ def test_gaussian_probabilities(self):
torch.allclose(ndtr(x), log_ndtr(x).exp(), atol=atol, rtol=rtol)
)

# test correctness of log_erfc(x) against log(erfc(x)) for positive and
# negative x
n = 16
x = torch.rand(n, dtype=dtype, device=self.device)
x = torch.cat((-x, x))
x.requires_grad = True
log_erfc_x = log_erfc(x)
special_log_erfc_x = torch.special.erfc(x).log()
self.assertTrue(
torch.allclose(log_erfc_x, special_log_erfc_x, atol=atol, rtol=rtol)
)
# testing backward passes
log_erfc_x.sum().backward()
x_grad = x.grad
x.grad[:] = 0
special_log_erfc_x.sum().backward()
special_x_grad = x.grad
self.assertTrue(
torch.allclose(x_grad, special_x_grad, atol=atol, rtol=rtol)
)

# testing robustness of log_erfc for large inputs
# large positive numbers are difficult for a naive implementation
x = torch.tensor(
[1e100 if dtype == torch.float64 else 1e10],
dtype=dtype,
device=self.device,
)
x = torch.cat((-x, x)) # looking at both tails
x.requires_grad = True
log_erfc_x = log_erfc(x)
self.assertTrue(
torch.allclose(
log_erfc_x.exp(), torch.special.erfc(x), atol=atol, rtol=rtol
)
)
self.assertFalse(log_erfc_x.isnan().any())
self.assertFalse(log_erfc_x.isinf().any())
# we can't just take the log of erfc because it will be -inf in the tail
self.assertTrue(torch.special.erfc(x).log().isinf().any())
# testing that gradients are usable floats
log_erfc_x.sum().backward()
self.assertFalse(x.grad.isnan().any())
self.assertFalse(x.grad.isinf().any())
# test correctness of log_erfc and log_erfcx
for special_f, custom_log_f in zip(
(torch.special.erfc, torch.special.erfcx), (log_erfc, log_erfcx)
):
with self.subTest(custom_log_f.__name__):
# first, testing for moderate values
n = 16
x = torch.rand(n, dtype=dtype, device=self.device)
x = torch.cat((-x, x))
x.requires_grad = True
custom_log_fx = custom_log_f(x)
special_log_fx = special_f(x).log()
self.assertAllClose(
custom_log_fx, special_log_fx, atol=atol, rtol=rtol
)
# testing backward passes
custom_log_fx.sum().backward()
x_grad = x.grad
x.grad[:] = 0
special_log_fx.sum().backward()
special_x_grad = x.grad
self.assertAllClose(x_grad, special_x_grad, atol=atol, rtol=rtol)

# testing robustness of log_erfc for large inputs
# large positive numbers are difficult for a naive implementation
x = torch.tensor(
[1e100 if dtype == torch.float64 else 1e10],
dtype=dtype,
device=self.device,
)
x = torch.cat((-x, x)) # looking at both tails
x.requires_grad = True
custom_log_fx = custom_log_f(x)
self.assertAllClose(
custom_log_fx.exp(),
special_f(x),
atol=atol,
rtol=rtol,
)
self.assertFalse(custom_log_fx.isnan().any())
self.assertFalse(custom_log_fx.isinf().any())
# we can't just take the log of erfc because the tail will be -inf
self.assertTrue(special_f(x).log().isinf().any())
# testing that gradients are usable floats
custom_log_fx.sum().backward()
self.assertFalse(x.grad.isnan().any())
self.assertFalse(x.grad.isinf().any())

# test limit behavior of log_ndtr
digits = 100 if dtype == torch.float64 else 20
@@ -292,6 +297,16 @@ def test_gaussian_probabilities(self):
a[2, 3] = b[2, 3]
log_prob_normal_in(a, b)

# testing gaussian hazard function
n = 16
x = torch.rand(n, dtype=dtype, device=self.device)
x = torch.cat((-x, x))
log_hx = standard_normal_log_hazard(x)
expected_log_hx = log_phi(x) - log_ndtr(-x)
self.assertAllClose(expected_log_hx, log_hx) # correctness
# NOTE: Could extend tests here similarly to log_erfc(x) tests above, but
# since the hazard functions are built on log_erfcx, not urgent.

with self.assertRaises(TypeError):
log_erfc(torch.tensor(1.0, dtype=torch.float16, device=self.device))

