From 8a7923acaa402567a7d10543aefed30461a432ac Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Thu, 13 Jul 2023 08:02:46 -0700
Subject: [PATCH] Enabling `SampleReducingMCAcquisitionFunctions` with constraints in `input_constructors` (#1932)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1932

This commit enables the principled constraint treatment of
`SampleReducingMCAcquisitionFunctions` through `input_constructors` and
`get_acquisition_function`.

Differential Revision: https://internalfb.com/D47365084

fbshipit-source-id: 6422ddeabde2cd9804b1e76e25236e2021f8c1a2
---
 botorch/acquisition/input_constructors.py   |  92 +++++++++++---
 botorch/acquisition/utils.py                |  21 +++-
 test/acquisition/test_input_constructors.py |  23 +++-
 test/acquisition/test_utils.py              | 131 +++++++++++++++++++-
 4 files changed, 245 insertions(+), 22 deletions(-)

diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py
index 6e6f7dcd62..f4ac50fd9c 100644
--- a/botorch/acquisition/input_constructors.py
+++ b/botorch/acquisition/input_constructors.py
@@ -78,6 +78,7 @@
 from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption
 from botorch.acquisition.risk_measures import RiskMeasureMCObjective
 from botorch.acquisition.utils import (
+    compute_best_feasible_objective,
     expand_trace_observations,
     get_optimal_samples,
     project_to_target_fidelity,
@@ -457,7 +458,9 @@ def construct_inputs_qEI(
     X_pending: Optional[Tensor] = None,
     sampler: Optional[MCSampler] = None,
     best_f: Optional[Union[float, Tensor]] = None,
-    **kwargs: Any,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+    eta: Union[Tensor, float] = 1e-3,
+    **ignored: Any,
 ) -> Dict[str, Any]:
     r"""Construct kwargs for the `qExpectedImprovement` constructor.
@@ -473,7 +476,15 @@
         sampler: The sampler used to draw base samples. If omitted, uses
             the acquisition function's default sampler.
         best_f: Threshold above (or below) which improvement is defined.
-        kwargs: Not used.
+        constraints: A list of constraint callables which map a Tensor of posterior
+            samples of dimension `sample_shape x batch-shape x q x m` to a
+            `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+            are considered satisfied if the output is less than zero.
+        eta: Temperature parameter(s) governing the smoothness of the sigmoid
+            approximation to the constraint indicators. For more details on this
+            parameter, see the docs of `compute_smoothed_constraint_indicator`.
+        ignored: Not used.
+
     Returns:
         A dict mapping kwarg names of the constructor to values.
     """
@@ -489,9 +500,11 @@
             training_data=training_data,
             objective=objective,
             posterior_transform=posterior_transform,
+            constraints=constraints,
+            model=model,
         )
 
-    return {**base_inputs, "best_f": best_f}
+    return {**base_inputs, "best_f": best_f, "constraints": constraints, "eta": eta}
 
 
 @acqf_input_constructor(qNoisyExpectedImprovement)
@@ -505,7 +518,9 @@ def construct_inputs_qNEI(
     X_baseline: Optional[Tensor] = None,
     prune_baseline: Optional[bool] = True,
     cache_root: Optional[bool] = True,
-    **kwargs: Any,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+    eta: Union[Tensor, float] = 1e-3,
+    **ignored: Any,
 ) -> Dict[str, Any]:
     r"""Construct kwargs for the `qNoisyExpectedImprovement` constructor.
@@ -527,7 +542,14 @@
         prune_baseline: If True, remove points in `X_baseline` that are highly
             unlikely to be the best point.
             This can significantly improve performance and is generally recommended.
-        kwargs: Not used.
+        constraints: A list of constraint callables which map a Tensor of posterior
+            samples of dimension `sample_shape x batch-shape x q x m` to a
+            `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+            are considered satisfied if the output is less than zero.
+        eta: Temperature parameter(s) governing the smoothness of the sigmoid
+            approximation to the constraint indicators. For more details on this
+            parameter, see the docs of `compute_smoothed_constraint_indicator`.
+        ignored: Not used.
 
     Returns:
         A dict mapping kwarg names of the constructor to values.
@@ -553,6 +575,8 @@ def construct_inputs_qNEI(
         "X_baseline": X_baseline,
         "prune_baseline": prune_baseline,
         "cache_root": cache_root,
+        "constraints": constraints,
+        "eta": eta,
     }
 
 
@@ -566,7 +590,9 @@ def construct_inputs_qPI(
     sampler: Optional[MCSampler] = None,
     tau: float = 1e-3,
     best_f: Optional[Union[float, Tensor]] = None,
-    **kwargs: Any,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+    eta: Union[Tensor, float] = 1e-3,
+    **ignored: Any,
 ) -> Dict[str, Any]:
     r"""Construct kwargs for the `qProbabilityOfImprovement` constructor.
@@ -588,13 +614,26 @@
         best_f: The best objective value observed so far (assumed noiseless). Can
             be a `batch_shape`-shaped tensor, which in case of a batched model
             specifies potentially different values for each element of the batch.
-        kwargs: Not used.
+        constraints: A list of constraint callables which map a Tensor of posterior
+            samples of dimension `sample_shape x batch-shape x q x m` to a
+            `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+            are considered satisfied if the output is less than zero.
+        eta: Temperature parameter(s) governing the smoothness of the sigmoid
+            approximation to the constraint indicators. For more details on this
+            parameter, see the docs of `compute_smoothed_constraint_indicator`.
+        ignored: Not used.
+
     Returns:
         A dict mapping kwarg names of the constructor to values.
     """
     if best_f is None:
-        best_f = get_best_f_mc(training_data=training_data, objective=objective)
-
+        best_f = get_best_f_mc(
+            training_data=training_data,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            constraints=constraints,
+            model=model,
+        )
     base_inputs = _construct_inputs_mc_base(
         model=model,
         objective=objective,
@@ -603,7 +642,13 @@
         X_pending=X_pending,
     )
 
-    return {**base_inputs, "tau": tau, "best_f": best_f}
+    return {
+        **base_inputs,
+        "tau": tau,
+        "best_f": best_f,
+        "constraints": constraints,
+        "eta": eta,
+    }
 
 
 @acqf_input_constructor(qUpperConfidenceBound)
@@ -615,7 +660,7 @@ def construct_inputs_qUCB(
     X_pending: Optional[Tensor] = None,
     sampler: Optional[MCSampler] = None,
     beta: float = 0.2,
-    **kwargs: Any,
+    **ignored: Any,
 ) -> Dict[str, Any]:
     r"""Construct kwargs for the `qUpperConfidenceBound` constructor.
@@ -631,7 +676,7 @@
         sampler: The sampler used to draw base samples. If omitted, uses
             the acquisition function's default sampler.
         beta: Controls tradeoff between mean and standard deviation in UCB.
-        kwargs: Not used.
+        ignored: Not used.
 
     Returns:
         A dict mapping kwarg names of the constructor to values.
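
Editor's illustration, not part of the patch: a minimal sketch of how the new `constraints` and `eta` arguments flow through an input constructor once this change lands. The model, dataset, and constraint threshold below are hypothetical, and `SupervisedDataset` construction details vary across BoTorch versions.

import torch
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.acquisition.objective import LinearMCObjective
from botorch.models import SingleTaskGP
from botorch.utils.datasets import SupervisedDataset

# Hypothetical training data: outcome 0 is the objective, outcome 1 a constraint.
X = torch.rand(8, 2, dtype=torch.double)
Y = torch.rand(8, 2, dtype=torch.double)
model = SingleTaskGP(X, Y)
training_data = SupervisedDataset(X=X, Y=Y)

constructor = get_acqf_input_constructor(qExpectedImprovement)
kwargs = constructor(
    model=model,
    training_data=training_data,
    objective=LinearMCObjective(weights=torch.tensor([1.0, 0.0], dtype=torch.double)),
    # Feasible iff the callable output is negative, i.e. outcome 1 < 0.5.
    constraints=[lambda samples: samples[..., 1] - 0.5],
    eta=1e-3,  # temperature of the smoothed feasibility indicator
)
# best_f is now the best *feasible* observed value, and constraints/eta are
# forwarded to the SampleReducingMCAcquisitionFunction machinery.
acqf = qExpectedImprovement(**kwargs)
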
@@ -1083,18 +1128,28 @@ def get_best_f_mc(
     training_data: MaybeDict[SupervisedDataset],
     objective: Optional[MCAcquisitionObjective] = None,
     posterior_transform: Optional[PosteriorTransform] = None,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+    model: Optional[Model] = None,
 ) -> Tensor:
     if isinstance(training_data, dict) and not _field_is_shared(
         training_data, fieldname="X"
     ):
         raise NotImplementedError("Currently only block designs are supported.")
+    X_baseline = _get_dataset_field(
+        training_data,
+        fieldname="X",
+        transform=lambda field: field(),
+        assert_shared=True,
+        first_only=True,
+    )
+
     Y = _get_dataset_field(
         training_data,
         fieldname="Y",
         transform=lambda field: field(),
         join_rule=lambda field_tensors: torch.cat(field_tensors, dim=-1),
-    )
+    )  # batch_shape x n x m
 
     if posterior_transform is not None:
         # retain the original tensor dimension since objective expects explicit
@@ -1111,7 +1166,16 @@
             "acquisition functions)."
         )
         objective = IdentityMCObjective()
-    return objective(Y).max(-1).values
+    obj = objective(Y, X=X_baseline)  # batch_shape x n
+    return compute_best_feasible_objective(
+        samples=Y,
+        obj=obj,
+        constraints=constraints,
+        model=model,
+        objective=objective,
+        posterior_transform=posterior_transform,
+        X_baseline=X_baseline,
+    )
 
 
 def optimize_objective(
diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index 9a1f4eab8b..a683896840 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -101,10 +101,19 @@ def get_acquisition_function(
     )
     # instantiate and return the requested acquisition function
     if acquisition_function_name in ("qEI", "qPI"):
-        obj = objective(
-            model.posterior(X_observed, posterior_transform=posterior_transform).mean
+        # Since these are the non-noisy variants, use the posterior mean at the observed
+        # inputs directly to compute the best feasible value without sampling.
+        Y = model.posterior(X_observed, posterior_transform=posterior_transform).mean
+        obj = objective(samples=Y, X=X_observed)
+        best_f = compute_best_feasible_objective(
+            samples=Y,
+            obj=obj,
+            constraints=constraints,
+            model=model,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_baseline=X_observed,
         )
-        best_f = obj.max(dim=-1).values
     if acquisition_function_name == "qEI":
         return monte_carlo.qExpectedImprovement(
             model=model,
@@ -113,6 +122,8 @@
             objective=objective,
             posterior_transform=posterior_transform,
             X_pending=X_pending,
+            constraints=constraints,
+            eta=eta,
         )
     elif acquisition_function_name == "qPI":
         return monte_carlo.qProbabilityOfImprovement(
@@ -123,6 +134,8 @@
             posterior_transform=posterior_transform,
             X_pending=X_pending,
             tau=kwargs.get("tau", 1e-3),
+            constraints=constraints,
+            eta=eta,
         )
     elif acquisition_function_name == "qNEI":
         return monte_carlo.qNoisyExpectedImprovement(
@@ -135,6 +148,8 @@
             prune_baseline=kwargs.get("prune_baseline", True),
             marginalize_dim=kwargs.get("marginalize_dim"),
             cache_root=kwargs.get("cache_root", True),
+            constraints=constraints,
+            eta=eta,
         )
     elif acquisition_function_name == "qSR":
         return monte_carlo.qSimpleRegret(
diff --git a/test/acquisition/test_input_constructors.py b/test/acquisition/test_input_constructors.py
index dce2e5e5e7..e2a59d34ad 100644
--- a/test/acquisition/test_input_constructors.py
+++ b/test/acquisition/test_input_constructors.py
@@ -139,13 +139,13 @@ def test_get_best_f_mc(self):
         best_f = get_best_f_mc(training_data=self.blockX_multiY, objective=obj)
 
         multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
-        best_f_expected = (multi_Y @ obj.weights).max()
+        best_f_expected = (multi_Y @ obj.weights).amax(dim=-1, keepdim=True)
         self.assertEqual(best_f, best_f_expected)
         post_tf = ScalarizedPosteriorTransform(weights=torch.ones(2))
         best_f = get_best_f_mc(
             training_data=self.blockX_multiY, posterior_transform=post_tf
         )
-        best_f_expected = (multi_Y.sum(dim=-1)).max()
+        best_f_expected = (multi_Y.sum(dim=-1)).amax(dim=-1, keepdim=True)
         self.assertEqual(best_f, best_f_expected)
 
     @mock.patch("botorch.acquisition.input_constructors.optimize_acqf")
@@ -350,6 +350,9 @@ def test_construct_inputs_qEI(self):
         self.assertIsNone(kwargs["objective"])
         self.assertIsNone(kwargs["X_pending"])
         self.assertIsNone(kwargs["sampler"])
+        self.assertIsNone(kwargs["constraints"])
+        self.assertIsInstance(kwargs["eta"], float)
+        self.assertTrue(kwargs["eta"] < 1)
         X_pending = torch.rand(2, 2)
         objective = LinearMCObjective(torch.rand(2))
         kwargs = c(
@@ -362,6 +365,9 @@
         self.assertTrue(torch.equal(kwargs["objective"].weights, objective.weights))
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
         self.assertIsNone(kwargs["sampler"])
+        self.assertIsNone(kwargs["constraints"])
+        self.assertIsInstance(kwargs["eta"], float)
+        self.assertTrue(kwargs["eta"] < 1)
         multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
         best_f_expected = objective(multi_Y).max()
         self.assertEqual(kwargs["best_f"], best_f_expected)
@@ -386,6 +392,10 @@ def test_construct_inputs_qNEI(self):
         self.assertIsNone(kwargs["sampler"])
         self.assertTrue(kwargs["prune_baseline"])
         self.assertTrue(torch.equal(kwargs["X_baseline"], self.blockX_blockY[0].X()))
+        self.assertIsNone(kwargs["constraints"])
+        self.assertIsInstance(kwargs["eta"], float)
+        self.assertTrue(kwargs["eta"] < 1)
+
         with self.assertRaisesRegex(ValueError, "Field `X` must be shared"):
             c(model=mock_model, training_data=self.multiX_multiY)
         X_baseline = torch.rand(2, 2)
@@ -401,6 +411,9 @@ def test_construct_inputs_qNEI(self):
         self.assertIsNone(kwargs["sampler"])
         self.assertFalse(kwargs["prune_baseline"])
         self.assertTrue(torch.equal(kwargs["X_baseline"], X_baseline))
+        self.assertIsNone(kwargs["constraints"])
+        self.assertIsInstance(kwargs["eta"], float)
+        self.assertTrue(kwargs["eta"] < 1)
 
     def test_construct_inputs_qPI(self):
         c = get_acqf_input_constructor(qProbabilityOfImprovement)
@@ -411,6 +424,9 @@ def test_construct_inputs_qPI(self):
         self.assertIsNone(kwargs["X_pending"])
         self.assertIsNone(kwargs["sampler"])
         self.assertEqual(kwargs["tau"], 1e-3)
+        self.assertIsNone(kwargs["constraints"])
+        self.assertIsInstance(kwargs["eta"], float)
+        self.assertTrue(kwargs["eta"] < 1)
         X_pending = torch.rand(2, 2)
         objective = LinearMCObjective(torch.rand(2))
         kwargs = c(
@@ -425,6 +441,9 @@
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
         self.assertIsNone(kwargs["sampler"])
         self.assertEqual(kwargs["tau"], 1e-2)
+        self.assertIsNone(kwargs["constraints"])
+        self.assertIsInstance(kwargs["eta"], float)
+        self.assertTrue(kwargs["eta"] < 1)
         multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
         best_f_expected = objective(multi_Y).max()
         self.assertEqual(kwargs["best_f"], best_f_expected)
diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py
index d0b5410d7e..4fdcf33487 100644
--- a/test/acquisition/test_utils.py
+++ b/test/acquisition/test_utils.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
+import math
 from unittest import mock
 
 import torch
@@ -60,12 +61,15 @@ def setUp(self):
         self.qmc = True
         self.ref_point = [0.0, 0.0]
         self.mo_objective = DummyMCMultiOutputObjective()
-        self.Y = torch.tensor([[1.0, 2.0]])
+        self.Y = torch.tensor([[1.0, 2.0]])  # (1 x 2)-dim multi-objective outcomes
         self.seed = 1
 
     @mock.patch(f"{monte_carlo.__name__}.qExpectedImprovement")
     def test_GetQEI(self, mock_acqf):
-        self.model = MockModel(MockPosterior(mean=torch.zeros(1, 2)))
+        n = len(self.X_observed)
+        mean = torch.arange(n, dtype=torch.double).view(-1, 1)
+        var = torch.ones_like(mean)
+        self.model = MockModel(MockPosterior(mean=mean, variance=var))
         acqf = get_acquisition_function(
             acquisition_function_name="qEI",
             model=self.model,
@@ -85,6 +89,8 @@
             objective=self.objective,
             posterior_transform=None,
             X_pending=self.X_pending,
+            constraints=None,
+            eta=1e-3,
         )
         # test batched model
         self.model = MockModel(MockPosterior(mean=torch.zeros(1, 2, 1)))
@@ -125,10 +131,50 @@
         )
         self.assertEqual(mock_acqf.call_args[-1]["best_f"].item(), -1.0)
 
+        # with constraints
+        upper_bound = self.Y[0, 0] + 1 / 2  # = 1.5
+        constraints = [lambda samples: samples[..., 0] - upper_bound]
+        eta = math.pi * 1e-2  # testing non-standard eta
+
+        acqf = get_acquisition_function(
+            acquisition_function_name="qEI",
+            model=self.model,
+            objective=self.objective,
+            X_observed=self.X_observed,
+            X_pending=self.X_pending,
+            mc_samples=self.mc_samples,
+            seed=self.seed,
+            marginalize_dim=0,
+            constraints=constraints,
+            eta=eta,
+        )
+        self.assertEqual(acqf, mock_acqf.return_value)
+        best_feasible_f = compute_best_feasible_objective(
+            samples=mean,
+            obj=self.objective(mean),
+            constraints=constraints,
+            model=self.model,
+            objective=self.objective,
+            X_baseline=self.X_observed,
+        )
+        mock_acqf.assert_called_with(
+            model=self.model,
+            best_f=best_feasible_f,
+            sampler=mock.ANY,
+            objective=self.objective,
+            posterior_transform=None,
+            X_pending=self.X_pending,
+            constraints=constraints,
+            eta=eta,
+        )
+
     @mock.patch(f"{monte_carlo.__name__}.qProbabilityOfImprovement")
     def test_GetQPI(self, mock_acqf):
         # basic test
-        self.model = MockModel(MockPosterior(mean=torch.zeros(1, 2)))
+        n = len(self.X_observed)
+        mean = torch.arange(n, dtype=torch.double).view(-1, 1)
+        var = torch.ones_like(mean)
+        self.model = MockModel(MockPosterior(mean=mean, variance=var))
         acqf = get_acquisition_function(
             acquisition_function_name="qPI",
             model=self.model,
@@ -148,6 +194,8 @@
             posterior_transform=None,
             X_pending=self.X_pending,
             tau=1e-3,
+            constraints=None,
+            eta=1e-3,
         )
         args, kwargs = mock_acqf.call_args
         self.assertEqual(args, ())
@@ -197,9 +245,54 @@
         )
         self.assertTrue(acqf == mock_acqf.return_value)
 
+        # with constraints
+        n = len(self.X_observed)
+        mean = torch.arange(n, dtype=torch.double).view(-1, 1)
+        var = torch.ones_like(mean)
+        self.model = MockModel(MockPosterior(mean=mean, variance=var))
+        upper_bound = self.Y[0, 0] + 1 / 2  # = 1.5
+        constraints = [lambda samples: samples[..., 0] - upper_bound]
+        eta = math.pi * 1e-2  # testing non-standard eta
+        acqf = get_acquisition_function(
+            acquisition_function_name="qPI",
+            model=self.model,
+            objective=self.objective,
+            X_observed=self.X_observed,
+            X_pending=self.X_pending,
+            mc_samples=self.mc_samples,
+            seed=self.seed,
+            marginalize_dim=0,
+            constraints=constraints,
+            eta=eta,
+        )
+        self.assertEqual(acqf, mock_acqf.return_value)
+        best_feasible_f = compute_best_feasible_objective(
+            samples=mean,
+            obj=self.objective(mean),
+            constraints=constraints,
+            model=self.model,
+            objective=self.objective,
+            X_baseline=self.X_observed,
+        )
+        mock_acqf.assert_called_with(
+            model=self.model,
+            best_f=best_feasible_f,
+            sampler=mock.ANY,
+            objective=self.objective,
+            posterior_transform=None,
+            X_pending=self.X_pending,
+            tau=1e-3,
+            constraints=constraints,
+            eta=eta,
+        )
+
     @mock.patch(f"{monte_carlo.__name__}.qNoisyExpectedImprovement")
     def test_GetQNEI(self, mock_acqf):
         # basic test
+        n = len(self.X_observed)
+        mean = torch.arange(n, dtype=torch.double).view(-1, 1)
+        var = torch.ones_like(mean)
+        self.model = MockModel(MockPosterior(mean=mean, variance=var))
         acqf = get_acquisition_function(
             acquisition_function_name="qNEI",
             model=self.model,
@@ -257,6 +350,38 @@
         self.assertEqual(sampler.seed, 2)
         self.assertTrue(torch.equal(kwargs["X_baseline"], self.X_observed))
 
+        # with constraints
+        upper_bound = self.Y[0, 0] + 1 / 2  # = 1.5
+        constraints = [lambda samples: samples[..., 0] - upper_bound]
+        eta = math.pi * 1e-2  # testing non-standard eta
+
+        acqf = get_acquisition_function(
+            acquisition_function_name="qNEI",
+            model=self.model,
+            objective=self.objective,
+            X_observed=self.X_observed,
+            X_pending=self.X_pending,
+            mc_samples=self.mc_samples,
+            seed=self.seed,
+            marginalize_dim=0,
+            constraints=constraints,
+            eta=eta,
+        )
+        self.assertEqual(acqf, mock_acqf.return_value)
+        mock_acqf.assert_called_with(
+            model=self.model,
+            X_baseline=self.X_observed,
+            sampler=mock.ANY,
+            objective=self.objective,
+            posterior_transform=None,
+            X_pending=self.X_pending,
+            prune_baseline=True,
+            marginalize_dim=0,
+            cache_root=True,
+            constraints=constraints,
+            eta=eta,
+        )
+
     @mock.patch(f"{monte_carlo.__name__}.qSimpleRegret")
     def test_GetQSR(self, mock_acqf):
         # basic test
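
Editor's illustration, not part of the patch: the tests above exercise the convention that a constraint callable's output must be negative for feasibility. Under that same convention, the high-level `get_acquisition_function` entry point extended by this patch can be used directly; the model and data below are hypothetical.

import torch
from botorch.acquisition.objective import LinearMCObjective
from botorch.acquisition.utils import get_acquisition_function
from botorch.models import SingleTaskGP

# Hypothetical data: outcome 0 is the objective, outcome 1 a constraint.
X = torch.rand(8, 2, dtype=torch.double)
Y = torch.rand(8, 2, dtype=torch.double)
model = SingleTaskGP(X, Y)

acqf = get_acquisition_function(
    acquisition_function_name="qNEI",
    model=model,
    objective=LinearMCObjective(weights=torch.tensor([1.0, 0.0], dtype=torch.double)),
    X_observed=X,
    mc_samples=128,
    seed=0,
    # Outcome 1 must stay below 0.5; satisfied iff the callable output is < 0.
    constraints=[lambda samples: samples[..., 1] - 0.5],
    eta=1e-3,
)
value = acqf(X[:1].unsqueeze(0))  # evaluate on a (1 x q=1 x d) candidate batch
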