Enabling SampleReducingMCAcquisitionFunctions with constraints in `input_constructors` (#1932)

Summary:
Pull Request resolved: #1932

This commit enables the principled constraint treatment provided by `SampleReducingMCAcquisitionFunction`s through `input_constructors` and `get_acquisition_function`.
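For illustration, a hedged sketch of how the new `constraints` and `eta` arguments flow through an input constructor after this change. The model, dataset, and constraint below are hypothetical, and the exact `SupervisedDataset` signature varies across BoTorch versions:

```python
import torch
from botorch.acquisition import qExpectedImprovement
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.objective import LinearMCObjective
from botorch.models import SingleTaskGP
from botorch.utils.datasets import SupervisedDataset

# Hypothetical two-outcome problem: outcome 0 is the objective, outcome 1
# acts as a constraint that is feasible iff its value is at most 0.5.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.stack([train_X.sum(dim=-1), train_X.prod(dim=-1)], dim=-1)
model = SingleTaskGP(train_X, train_Y)
dataset = SupervisedDataset(X=train_X, Y=train_Y)  # signature varies by version

constructor = get_acqf_input_constructor(qExpectedImprovement)
kwargs = constructor(
    model=model,
    training_data=dataset,
    objective=LinearMCObjective(weights=torch.tensor([1.0, 0.0], dtype=torch.double)),
    # Constraint callables map posterior samples of shape `... x q x m`
    # to `... x q`; values <= 0 count as feasible.
    constraints=[lambda Z: Z[..., 1] - 0.5],
    eta=1e-3,  # temperature of the sigmoid constraint relaxation
)
acqf = qExpectedImprovement(**kwargs)
```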

Differential Revision: https://internalfb.com/D47365084

fbshipit-source-id: 6422ddeabde2cd9804b1e76e25236e2021f8c1a2
SebastianAment authored and facebook-github-bot committed Jul 13, 2023
1 parent a8fa242 commit 8a7923a
Showing 4 changed files with 245 additions and 22 deletions.
92 changes: 78 additions & 14 deletions botorch/acquisition/input_constructors.py
@@ -78,6 +78,7 @@
from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption
from botorch.acquisition.risk_measures import RiskMeasureMCObjective
from botorch.acquisition.utils import (
compute_best_feasible_objective,
expand_trace_observations,
get_optimal_samples,
project_to_target_fidelity,
@@ -457,7 +458,9 @@ def construct_inputs_qEI(
X_pending: Optional[Tensor] = None,
sampler: Optional[MCSampler] = None,
best_f: Optional[Union[float, Tensor]] = None,
**kwargs: Any,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
eta: Union[Tensor, float] = 1e-3,
**ignored: Any,
) -> Dict[str, Any]:
r"""Construct kwargs for the `qExpectedImprovement` constructor.
@@ -473,7 +476,15 @@
sampler: The sampler used to draw base samples. If omitted, uses
the acquisition function's default sampler.
best_f: Threshold above (or below) which improvement is defined.
kwargs: Not used.
constraints: A list of constraint callables which map a Tensor of posterior
samples of dimension `sample_shape x batch-shape x q x m`-dim to a
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
are considered satisfied if the output is less than zero.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details on this
parameter, see the docs of `compute_smoothed_constraint_indicator`.
ignored: Not used.
Returns:
A dict mapping kwarg names of the constructor to values.
"""
@@ -489,9 +500,11 @@
training_data=training_data,
objective=objective,
posterior_transform=posterior_transform,
constraints=constraints,
model=model,
)

return {**base_inputs, "best_f": best_f}
return {**base_inputs, "best_f": best_f, "constraints": constraints, "eta": eta}


@acqf_input_constructor(qNoisyExpectedImprovement)
@@ -505,7 +518,9 @@ def construct_inputs_qNEI(
X_baseline: Optional[Tensor] = None,
prune_baseline: Optional[bool] = True,
cache_root: Optional[bool] = True,
**kwargs: Any,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
eta: Union[Tensor, float] = 1e-3,
**ignored: Any,
) -> Dict[str, Any]:
r"""Construct kwargs for the `qNoisyExpectedImprovement` constructor.
@@ -527,7 +542,14 @@
prune_baseline: If True, remove points in `X_baseline` that are
highly unlikely to be the best point. This can significantly
improve performance and is generally recommended.
kwargs: Not used.
constraints: A list of constraint callables which map a Tensor of posterior
samples of dimension `sample_shape x batch-shape x q x m`-dim to a
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
are considered satisfied if the output is less than zero.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details on this
parameter, see the docs of `compute_smoothed_constraint_indicator`.
ignored: Not used.
Returns:
A dict mapping kwarg names of the constructor to values.
@@ -553,6 +575,8 @@ def construct_inputs_qNEI(
"X_baseline": X_baseline,
"prune_baseline": prune_baseline,
"cache_root": cache_root,
"constraints": constraints,
"eta": eta,
}
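The `eta` parameter deserves a note: the hard 0/1 feasibility indicator is replaced by a sigmoid whose sharpness `eta` controls, keeping the acquisition value differentiable. A minimal sketch of that relaxation (illustrative; the library's actual logic lives in `compute_smoothed_constraint_indicator`):

```python
import torch

def smoothed_feasibility(constraint_values: torch.Tensor, eta: float = 1e-3) -> torch.Tensor:
    # Approximates the indicator 1[constraint_values <= 0]; as eta -> 0 this
    # approaches the hard indicator while remaining differentiable, which is
    # what makes the constrained values usable with gradient-based optimizers.
    return torch.sigmoid(-constraint_values / eta)
```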


@@ -566,7 +590,9 @@ def construct_inputs_qPI(
sampler: Optional[MCSampler] = None,
tau: float = 1e-3,
best_f: Optional[Union[float, Tensor]] = None,
**kwargs: Any,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
eta: Union[Tensor, float] = 1e-3,
**ignored: Any,
) -> Dict[str, Any]:
r"""Construct kwargs for the `qProbabilityOfImprovement` constructor.
@@ -588,13 +614,26 @@
best_f: The best objective value observed so far (assumed noiseless). Can
be a `batch_shape`-shaped tensor, which in case of a batched model
specifies potentially different values for each element of the batch.
kwargs: Not used.
constraints: A list of constraint callables which map a Tensor of posterior
samples of dimension `sample_shape x batch-shape x q x m`-dim to a
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
are considered satisfied if the output is less than zero.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details on this
parameter, see the docs of `compute_smoothed_constraint_indicator`.
ignored: Not used.
Returns:
A dict mapping kwarg names of the constructor to values.
"""
if best_f is None:
best_f = get_best_f_mc(training_data=training_data, objective=objective)

best_f = get_best_f_mc(
training_data=training_data,
objective=objective,
posterior_transform=posterior_transform,
constraints=constraints,
model=model,
)
base_inputs = _construct_inputs_mc_base(
model=model,
objective=objective,
@@ -603,7 +642,13 @@
X_pending=X_pending,
)

return {**base_inputs, "tau": tau, "best_f": best_f}
return {
**base_inputs,
"tau": tau,
"best_f": best_f,
"constraints": constraints,
"eta": eta,
}


@acqf_input_constructor(qUpperConfidenceBound)
@@ -615,7 +660,7 @@ def construct_inputs_qUCB(
X_pending: Optional[Tensor] = None,
sampler: Optional[MCSampler] = None,
beta: float = 0.2,
**kwargs: Any,
**ignored: Any,
) -> Dict[str, Any]:
r"""Construct kwargs for the `qUpperConfidenceBound` constructor.
@@ -631,7 +676,7 @@
sampler: The sampler used to draw base samples. If omitted, uses
the acquisition function's default sampler.
beta: Controls tradeoff between mean and standard deviation in UCB.
kwargs: Not used.
ignored: Not used.
Returns:
A dict mapping kwarg names of the constructor to values.
@@ -1083,18 +1128,28 @@ def get_best_f_mc(
training_data: MaybeDict[SupervisedDataset],
objective: Optional[MCAcquisitionObjective] = None,
posterior_transform: Optional[PosteriorTransform] = None,
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
model: Optional[Model] = None,
) -> Tensor:
if isinstance(training_data, dict) and not _field_is_shared(
training_data, fieldname="X"
):
raise NotImplementedError("Currently only block designs are supported.")

X_baseline = _get_dataset_field(
training_data,
fieldname="X",
transform=lambda field: field(),
assert_shared=True,
first_only=True,
)

Y = _get_dataset_field(
training_data,
fieldname="Y",
transform=lambda field: field(),
join_rule=lambda field_tensors: torch.cat(field_tensors, dim=-1),
)
)  # batch_shape x n x m

if posterior_transform is not None:
# retain the original tensor dimension since objective expects explicit
@@ -1111,7 +1166,16 @@
"acquisition functions)."
)
objective = IdentityMCObjective()
return objective(Y).max(-1).values
obj = objective(Y, X=X_baseline) # batch_shape x n
return compute_best_feasible_objective(
samples=Y,
obj=obj,
constraints=constraints,
model=model,
objective=objective,
posterior_transform=posterior_transform,
X_baseline=X_baseline,
)
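Conceptually, this replaces the old `objective(Y).max(-1).values` with a feasibility-aware maximum. A simplified sketch of the idea (the library additionally falls back to an infeasible-cost estimate when no point is feasible, rather than the naive `-inf` used here):

```python
import torch
from typing import Callable, List

def best_feasible_objective_sketch(
    obj: torch.Tensor,  # `batch_shape x n` objective values
    samples: torch.Tensor,  # `batch_shape x n x m` outcome values
    constraints: List[Callable[[torch.Tensor], torch.Tensor]],
) -> torch.Tensor:
    # A point is feasible iff every constraint callable is <= 0 there.
    feasible = torch.ones(obj.shape, dtype=torch.bool)
    for c in constraints:
        feasible &= c(samples) <= 0
    # Mask out infeasible points before taking the max; keepdim matches
    # the `batch_shape x 1` output shape used downstream.
    masked = torch.where(feasible, obj, torch.full_like(obj, float("-inf")))
    return masked.amax(dim=-1, keepdim=True)
```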


def optimize_objective(
21 changes: 18 additions & 3 deletions botorch/acquisition/utils.py
@@ -101,10 +101,19 @@ def get_acquisition_function(
)
# instantiate and return the requested acquisition function
if acquisition_function_name in ("qEI", "qPI"):
obj = objective(
model.posterior(X_observed, posterior_transform=posterior_transform).mean
# Since these are the non-noisy variants, use the posterior mean at the observed
# inputs directly to compute the best feasible value without sampling.
Y = model.posterior(X_observed, posterior_transform=posterior_transform).mean
obj = objective(samples=Y, X=X_observed)
best_f = compute_best_feasible_objective(
samples=Y,
obj=obj,
constraints=constraints,
model=model,
objective=objective,
posterior_transform=posterior_transform,
X_baseline=X_observed,
)
best_f = obj.max(dim=-1).values
if acquisition_function_name == "qEI":
return monte_carlo.qExpectedImprovement(
model=model,
@@ -113,6 +122,8 @@
objective=objective,
posterior_transform=posterior_transform,
X_pending=X_pending,
constraints=constraints,
eta=eta,
)
elif acquisition_function_name == "qPI":
return monte_carlo.qProbabilityOfImprovement(
@@ -123,6 +134,8 @@
posterior_transform=posterior_transform,
X_pending=X_pending,
tau=kwargs.get("tau", 1e-3),
constraints=constraints,
eta=eta,
)
elif acquisition_function_name == "qNEI":
return monte_carlo.qNoisyExpectedImprovement(
@@ -135,6 +148,8 @@
prune_baseline=kwargs.get("prune_baseline", True),
marginalize_dim=kwargs.get("marginalize_dim"),
cache_root=kwargs.get("cache_root", True),
constraints=constraints,
eta=eta,
)
elif acquisition_function_name == "qSR":
return monte_carlo.qSimpleRegret(
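Putting the `get_acquisition_function` changes together, an end-to-end sketch (the model, data, and constraint are hypothetical; remaining arguments are left at their defaults):

```python
import torch
from botorch.acquisition.objective import LinearMCObjective
from botorch.acquisition.utils import get_acquisition_function
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.stack([train_X.sum(dim=-1), train_X.prod(dim=-1)], dim=-1)
model = SingleTaskGP(train_X, train_Y)

acqf = get_acquisition_function(
    acquisition_function_name="qEI",
    model=model,
    objective=LinearMCObjective(weights=torch.tensor([1.0, 0.0], dtype=torch.double)),
    X_observed=train_X,
    constraints=[lambda Z: Z[..., 1] - 0.5],  # feasible iff outcome 1 <= 0.5
    eta=1e-3,
)
values = acqf(train_X[:3].unsqueeze(0))  # evaluate a hypothetical q=3 batch
```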
23 changes: 21 additions & 2 deletions test/acquisition/test_input_constructors.py
@@ -139,13 +139,13 @@ def test_get_best_f_mc(self):
best_f = get_best_f_mc(training_data=self.blockX_multiY, objective=obj)

multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
best_f_expected = (multi_Y @ obj.weights).max()
best_f_expected = (multi_Y @ obj.weights).amax(dim=-1, keepdim=True)
self.assertEqual(best_f, best_f_expected)
post_tf = ScalarizedPosteriorTransform(weights=torch.ones(2))
best_f = get_best_f_mc(
training_data=self.blockX_multiY, posterior_transform=post_tf
)
best_f_expected = (multi_Y.sum(dim=-1)).max()
best_f_expected = (multi_Y.sum(dim=-1)).amax(dim=-1, keepdim=True)
self.assertEqual(best_f, best_f_expected)
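The move from `.max()` to `.amax(dim=-1, keepdim=True)` in the expected values tracks the new return shape of `get_best_f_mc`, which now keeps an explicit trailing dimension. Illustratively:

```python
import torch

y = torch.tensor([[0.1, 0.7, 0.3]])
y.max()                        # tensor(0.7000): a 0-dim scalar
y.amax(dim=-1, keepdim=True)   # tensor([[0.7000]]): shape (1, 1)
```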

@mock.patch("botorch.acquisition.input_constructors.optimize_acqf")
@@ -350,6 +350,9 @@ def test_construct_inputs_qEI(self):
self.assertIsNone(kwargs["objective"])
self.assertIsNone(kwargs["X_pending"])
self.assertIsNone(kwargs["sampler"])
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)
X_pending = torch.rand(2, 2)
objective = LinearMCObjective(torch.rand(2))
kwargs = c(
@@ -362,6 +365,9 @@
self.assertTrue(torch.equal(kwargs["objective"].weights, objective.weights))
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
self.assertIsNone(kwargs["sampler"])
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)
multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
best_f_expected = objective(multi_Y).max()
self.assertEqual(kwargs["best_f"], best_f_expected)
@@ -386,6 +392,10 @@ def test_construct_inputs_qNEI(self):
self.assertIsNone(kwargs["sampler"])
self.assertTrue(kwargs["prune_baseline"])
self.assertTrue(torch.equal(kwargs["X_baseline"], self.blockX_blockY[0].X()))
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)

with self.assertRaisesRegex(ValueError, "Field `X` must be shared"):
c(model=mock_model, training_data=self.multiX_multiY)
X_baseline = torch.rand(2, 2)
@@ -401,6 +411,9 @@
self.assertIsNone(kwargs["sampler"])
self.assertFalse(kwargs["prune_baseline"])
self.assertTrue(torch.equal(kwargs["X_baseline"], X_baseline))
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)

def test_construct_inputs_qPI(self):
c = get_acqf_input_constructor(qProbabilityOfImprovement)
@@ -411,6 +424,9 @@
self.assertIsNone(kwargs["X_pending"])
self.assertIsNone(kwargs["sampler"])
self.assertEqual(kwargs["tau"], 1e-3)
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)
X_pending = torch.rand(2, 2)
objective = LinearMCObjective(torch.rand(2))
kwargs = c(
@@ -425,6 +441,9 @@
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
self.assertIsNone(kwargs["sampler"])
self.assertEqual(kwargs["tau"], 1e-2)
self.assertIsNone(kwargs["constraints"])
self.assertIsInstance(kwargs["eta"], float)
self.assertTrue(kwargs["eta"] < 1)
multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
best_f_expected = objective(multi_Y).max()
self.assertEqual(kwargs["best_f"], best_f_expected)