Skip to content

Commit

Permalink
prior-guided acquisition function (#1920)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1920

This adds an acquisition function wrapper for prior-guided AFs.

Differential Revision: D47248296

fbshipit-source-id: 5ff20f413e3476be916829f5d58735e8ec4299fa
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jul 6, 2023
1 parent f6ed868 commit e10755a
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 2 deletions.
2 changes: 2 additions & 0 deletions botorch/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
AnalyticExpectedUtilityOfBestOption,
PairwiseBayesianActiveLearningByDisagreement,
)
from botorch.acquisition.prior_guided import PriorGuidedAcquisitionFunction
from botorch.acquisition.proximal import ProximalAcquisitionFunction
from botorch.acquisition.utils import get_acquisition_function

Expand All @@ -78,6 +79,7 @@
"PairwiseBayesianActiveLearningByDisagreement",
"PairwiseMCPosteriorVariance",
"PosteriorMean",
"PriorGuidedAcquisitionFunction",
"ProbabilityOfImprovement",
"ProximalAcquisitionFunction",
"UpperConfidenceBound",
Expand Down
79 changes: 79 additions & 0 deletions botorch/acquisition/prior_guided.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


"""
Prior-Guided Acquisition Functions
References
.. [Hvarfner2022]
C. Hvarfner, D. Stoll, A. Souza, M. Lindauer, F. Hutter, L. Nardi. PiBO:
Augmenting Acquisition Functions with User Beliefs for Bayesian Optimization.
ICLR 2022.
"""
from __future__ import annotations

from typing import Optional

from botorch.acquisition.acquisition import AcquisitionFunction
from torch import Tensor

from torch.nn import Module


class PriorGuidedAcquisitionFunction(AcquisitionFunction):
r"""Class for weighting acquisition functions by a prior distribution.
See [Hvarfner2022]_ for details.
"""

def __init__(
self,
acq_function: AcquisitionFunction,
prior_module: Module,
log: bool = False,
prior_exponent: float = 1.0,
) -> None:
r"""Initialize the prior-guided acquisition function.
Args:
acq_function: The base acquisition function.
prior_module: A Module that computes the probability
(or log probability) for the provided inputs.
log: A boolean that should be true if the acquisition function emits a
log-transformed value and the prior module emits a log probability.
prior_exponent: The exponent applied to the prior. This can be used
for example to decay the effect the prior over time as in
[Hvarfner2022]_.
"""
Module.__init__(self)
self.acq_func = acq_function
self.prior_module = prior_module
self._log = log
self._prior_exponent = prior_exponent

@property
def X_pending(self):
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

@X_pending.setter
def X_pending(self, X_pending: Optional[Tensor]):
r"""Sets the `X_pending` of the base acquisition function."""
self.acq_func.X_pending = X_pending

def forward(self, X: Tensor) -> Tensor:
r"""Compute the acquisition function weighted by the prior."""
if self._log:
return self.acq_func(X) + self.prior_module(X) * self._prior_exponent
return self.acq_func(X) * self.prior_module(X).pow(self._prior_exponent)
9 changes: 7 additions & 2 deletions sphinx/source/acquisition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Multi-Objective Analytic Acquisition Functions
.. automodule:: botorch.acquisition.multi_objective.analytic
:members:
:exclude-members: MultiObjectiveAnalyticAcquisitionFunction

Multi-Objective Joint Entropy Search Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.joint_entropy_search
Expand All @@ -86,7 +86,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.multi_fidelity
:members:

Multi-Objective Predictive Entropy Search Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search
Expand Down Expand Up @@ -175,6 +175,11 @@ Penalized Acquisition Function Wrapper
.. automodule:: botorch.acquisition.penalized
:members:

Prior-Guided Acquisition Function Wrapper
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.prior_guided
:members:

Proximal Acquisition Function Wrapper
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.proximal
Expand Down
55 changes: 55 additions & 0 deletions test/acquisition/test_prior_guided.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from itertools import product

import torch
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.acquisition.prior_guided import PriorGuidedAcquisitionFunction
from botorch.models import SingleTaskGP
from botorch.utils.testing import BotorchTestCase
from torch.nn import Module


class DummyPrior(Module):
def forward(self, X):
p = torch.distributions.Normal(0, 1)
# sum over d and q dimensions
return p.log_prob(X).sum(dim=-1).sum(dim=-1).exp()


class TestPriorGuidedAcquisitionFunction(BotorchTestCase):
def test_prior_guided_acquisition_function(self):
prior = DummyPrior()
for dtype in (torch.float, torch.double):
train_X = torch.rand(5, 3, dtype=dtype, device=self.device)
train_Y = train_X.norm(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y).eval()
qEI = qExpectedImprovement(model, best_f=0.0)
for batch_shape, q, use_log, exponent in product(
([], [2]), (1, 2), (False, True), (1.0, 2.0)
):
af = PriorGuidedAcquisitionFunction(
acq_function=qEI,
prior_module=prior,
log=use_log,
prior_exponent=exponent,
)
test_X = torch.rand(*batch_shape, q, 3, dtype=dtype, device=self.device)
with torch.no_grad():
val = af(test_X)
prob = prior(test_X)
ei = qEI(test_X)
if use_log:
expected_val = prob * exponent + ei
else:
expected_val = prob.pow(exponent) * ei
self.assertTrue(torch.equal(val, expected_val))
# test set_X_pending
X_pending = torch.rand(2, 3, dtype=dtype, device=self.device)
af.X_pending = X_pending
self.assertTrue(torch.equal(X_pending, af.acq_func.X_pending))
self.assertTrue(torch.equal(X_pending, af.X_pending))

0 comments on commit e10755a

Please sign in to comment.