Skip to content

Commit

Permalink
compute_smoothed_constraint_indicator -> `compute_smoothed_feasibil…
Browse files Browse the repository at this point in the history
…ity_indicator` (#1935)

Summary:
Pull Request resolved: #1935

D47365085 introduced the aptly named `compute_feasibility_indicator`. This commit brings the smoothed counterpart in alignment with the naming convention.

Reviewed By: Balandat

Differential Revision: D47436246

fbshipit-source-id: 5612b1cbb69993b475702779f62bb2cbfada9c96
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Jul 13, 2023
1 parent 8a7923a commit 538bbc3
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 19 deletions.
6 changes: 3 additions & 3 deletions botorch/acquisition/input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def construct_inputs_qEI(
are considered satisfied if the output is less than zero.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details, on this
parameter, see the docs of `compute_smoothed_constraint_indicator`.
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
ignored: Not used.
Returns:
Expand Down Expand Up @@ -548,7 +548,7 @@ def construct_inputs_qNEI(
are considered satisfied if the output is less than zero.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details, on this
parameter, see the docs of `compute_smoothed_constraint_indicator`.
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
ignored: Not used.
Returns:
Expand Down Expand Up @@ -620,7 +620,7 @@ def construct_inputs_qPI(
are considered satisfied if the output is less than zero.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details, on this
parameter, see the docs of `compute_smoothed_constraint_indicator`.
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
ignored: Not used.
Returns:
Expand Down
12 changes: 6 additions & 6 deletions botorch/acquisition/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.utils.objective import compute_smoothed_constraint_indicator
from botorch.utils.objective import compute_smoothed_feasibility_indicator
from botorch.utils.transforms import (
concatenate_pending_points,
match_batch_shape,
Expand Down Expand Up @@ -215,7 +215,7 @@ def __init__(
acquistion utilities, e.g. all improvement-based acquisition functions.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details, on this
parameter, see the docs of `compute_smoothed_constraint_indicator`.
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
"""
if constraints is not None and isinstance(objective, ConstrainedMCObjective):
raise ValueError(
Expand Down Expand Up @@ -305,7 +305,7 @@ def _apply_constraints(self, acqval: Tensor, samples: Tensor) -> Tensor:
"Constraint-weighting requires unconstrained "
"acquisition values to be non-negative."
)
acqval = acqval * compute_smoothed_constraint_indicator(
acqval = acqval * compute_smoothed_feasibility_indicator(
constraints=self._constraints, samples=samples, eta=self._eta
)
return acqval
Expand Down Expand Up @@ -366,7 +366,7 @@ def __init__(
are considered satisfied if the output is less than zero.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details, on this
parameter, see the docs of `compute_smoothed_constraint_indicator`.
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
"""
super().__init__(
model=model,
Expand Down Expand Up @@ -457,7 +457,7 @@ def __init__(
are considered satisfied if the output is less than zero.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details, on this
parameter, see the docs of `compute_smoothed_constraint_indicator`.
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
TODO: similar to qNEHVI, when we are using sequential greedy candidate
selection, we could incorporate pending points X_baseline and compute
Expand Down Expand Up @@ -671,7 +671,7 @@ def __init__(
scalar is less than zero.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. For more details, on this
parameter, see the docs of `compute_smoothed_constraint_indicator`.
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
"""
super().__init__(
model=model,
Expand Down
6 changes: 3 additions & 3 deletions botorch/acquisition/multi_objective/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from botorch.utils.multi_objective.box_decompositions.utils import (
_pad_batch_pareto_frontier,
)
from botorch.utils.objective import compute_smoothed_constraint_indicator
from botorch.utils.objective import compute_smoothed_feasibility_indicator
from botorch.utils.torch import BufferDict
from botorch.utils.transforms import (
concatenate_pending_points,
Expand Down Expand Up @@ -279,7 +279,7 @@ def _compute_qehvi(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
obj = self.objective(samples, X=X)
q = obj.shape[-2]
if self.constraints is not None:
feas_weights = compute_smoothed_constraint_indicator(
feas_weights = compute_smoothed_feasibility_indicator(
constraints=self.constraints, samples=samples, eta=self.eta
) # `sample_shape x batch-shape x q`
self._cache_q_subset_indices(q_out=q)
Expand Down Expand Up @@ -414,7 +414,7 @@ def __init__(
tensor the length of the tensor must match the number of provided
constraints. The i-th constraint is then estimated with the i-th
eta value. For more details, on this parameter, see the docs of
`compute_smoothed_constraint_indicator`.
`compute_smoothed_feasibility_indicator`.
prune_baseline: If True, remove points in `X_baseline` that are
highly unlikely to be the pareto optimal and better than the
reference point. This can significantly improve computation time and
Expand Down
4 changes: 2 additions & 2 deletions botorch/utils/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def apply_constraints_nonnegative_soft(
Returns:
A `n_samples x b x q (x m')`-dim tensor of feasibility-weighted objectives.
"""
w = compute_smoothed_constraint_indicator(
w = compute_smoothed_feasibility_indicator(
constraints=constraints, samples=samples, eta=eta
)
if obj.dim() == samples.dim():
Expand Down Expand Up @@ -116,7 +116,7 @@ def compute_feasibility_indicator(
return ind


def compute_smoothed_constraint_indicator(
def compute_smoothed_feasibility_indicator(
constraints: List[Callable[[Tensor], Tensor]],
samples: Tensor,
eta: Union[Tensor, float],
Expand Down
10 changes: 5 additions & 5 deletions test/utils/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from botorch.utils import apply_constraints, get_objective_weights_transform
from botorch.utils.objective import (
compute_feasibility_indicator,
compute_smoothed_constraint_indicator,
compute_smoothed_feasibility_indicator,
)
from botorch.utils.testing import BotorchTestCase
from torch import Tensor
Expand Down Expand Up @@ -196,14 +196,14 @@ def test_constraint_indicators(self):
self.assertAllClose(ind, torch.zeros_like(ind))
self.assertEqual(ind.dtype, torch.bool)

smoothed_ind = compute_smoothed_constraint_indicator(
smoothed_ind = compute_smoothed_feasibility_indicator(
constraints=[zeros_f], samples=samples, eta=1e-3
)
self.assertAllClose(smoothed_ind, ones_f(samples) / 2)

# two constraints
samples = torch.randn(1)
smoothed_ind = compute_smoothed_constraint_indicator(
smoothed_ind = compute_smoothed_feasibility_indicator(
constraints=[zeros_f, zeros_f],
samples=samples,
eta=1e-3,
Expand All @@ -218,13 +218,13 @@ def test_constraint_indicators(self):
)
self.assertAllClose(ind, torch.ones_like(ind))

smoothed_ind = compute_smoothed_constraint_indicator(
smoothed_ind = compute_smoothed_feasibility_indicator(
constraints=[minus_one_f], samples=samples, eta=1e-3
)
self.assertTrue((smoothed_ind > 3 / 4).all())

with self.assertRaisesRegex(ValueError, "Number of provided constraints"):
compute_smoothed_constraint_indicator(
compute_smoothed_feasibility_indicator(
constraints=[zeros_f, zeros_f],
samples=samples,
eta=torch.tensor([0.1], device=self.device),
Expand Down

0 comments on commit 538bbc3

Please sign in to comment.