diff --git a/botorch/generation/gen.py b/botorch/generation/gen.py
index e192030f1f..ebfbc573ca 100644
--- a/botorch/generation/gen.py
+++ b/botorch/generation/gen.py
@@ -114,15 +114,12 @@ def gen_candidates_scipy(
     # if there are fixed features we may optimize over a domain of lower dimension
     reduced_domain = False
     if fixed_features:
-        # TODO: We can support fixed features, see Max's comment on D33551393. We can
-        # consider adding this at a later point.
-        if nonlinear_inequality_constraints:
-            raise NotImplementedError(
-                "Fixed features are not supported when non-linear inequality "
-                "constraints are given."
-            )
-        # if there are no constraints things are straightforward
-        if not (inequality_constraints or equality_constraints):
+        # if there are no constraints, things are straightforward
+        if not (
+            inequality_constraints
+            or equality_constraints
+            or nonlinear_inequality_constraints
+        ):
             reduced_domain = True
         # if there are we need to make sure features are fixed to specific values
         else:
@@ -137,6 +134,7 @@ def gen_candidates_scipy(
             upper_bounds=upper_bounds,
             inequality_constraints=inequality_constraints,
             equality_constraints=equality_constraints,
+            nonlinear_inequality_constraints=nonlinear_inequality_constraints,
         )
         # call the routine with no fixed_features
         clamped_candidates, batch_acquisition = gen_candidates_scipy(
@@ -146,6 +144,7 @@ def gen_candidates_scipy(
             upper_bounds=_no_fixed_features.upper_bounds,
             inequality_constraints=_no_fixed_features.inequality_constraints,
             equality_constraints=_no_fixed_features.equality_constraints,
+            nonlinear_inequality_constraints=_no_fixed_features.nonlinear_inequality_constraints,  # noqa: E501
             options=options,
             fixed_features=None,
             timeout_sec=timeout_sec,
@@ -342,6 +341,7 @@ def gen_candidates_torch(
             upper_bounds=upper_bounds,
             inequality_constraints=None,
             equality_constraints=None,
+            nonlinear_inequality_constraints=None,
         )
 
         # call the routine with no fixed_features
diff --git a/botorch/generation/utils.py b/botorch/generation/utils.py
index 181feffc94..f6cc395dd6 100644
--- a/botorch/generation/utils.py
+++ b/botorch/generation/utils.py
@@ -7,11 +7,14 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from botorch.acquisition import AcquisitionFunction, FixedFeatureAcquisitionFunction
-from botorch.optim.parameter_constraints import _generate_unfixed_lin_constraints
+from botorch.optim.parameter_constraints import (
+    _generate_unfixed_lin_constraints,
+    _generate_unfixed_nonlin_constraints,
+)
 from torch import Tensor
 
 
@@ -63,6 +66,7 @@ class _NoFixedFeatures:
     upper_bounds: Optional[Union[float, Tensor]]
     inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]]
     equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]]
+    nonlinear_inequality_constraints: Optional[List[Callable[[Tensor], Tensor]]]
 
 
 def _remove_fixed_features_from_optimization(
@@ -73,6 +77,7 @@ def _remove_fixed_features_from_optimization(
     upper_bounds: Optional[Union[float, Tensor]],
     inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]],
     equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]],
+    nonlinear_inequality_constraints: Optional[List[Callable[[Tensor], Tensor]]],
 ) -> _NoFixedFeatures:
     """
     Given a set of non-empty fixed features, this function effectively reduces the
@@ -98,6 +103,11 @@ def _remove_fixed_features_from_optimization(
         equality constraints: A list of tuples (indices, coefficients, rhs),
             with each tuple encoding an equality constraint of the form
             `sum_i (X[indices[i]] * coefficients[i]) = rhs`.
+        nonlinear_inequality_constraints: A list of callables that represent
+            non-linear inequality constraints of the form `callable(x) >= 0`. Each
+            callable is expected to take a `(num_restarts) x q x d`-dim tensor as
+            an input and return a `(num_restarts) x q`-dim tensor with the
+            constraint values.
 
     Returns:
         _NoFixedFeatures dataclass object.
@@ -140,6 +150,11 @@ def _remove_fixed_features_from_optimization(
         dimension=d,
         eq=True,
     )
+    nonlinear_inequality_constraints = _generate_unfixed_nonlin_constraints(
+        constraints=nonlinear_inequality_constraints,
+        fixed_features=fixed_features,
+        dimension=d,
+    )
     return _NoFixedFeatures(
         acquisition_function=acquisition_function,
         initial_conditions=initial_conditions,
@@ -147,4 +162,5 @@ def _remove_fixed_features_from_optimization(
         upper_bounds=upper_bounds,
         inequality_constraints=inequality_constraints,
         equality_constraints=equality_constraints,
+        nonlinear_inequality_constraints=nonlinear_inequality_constraints,
     )
diff --git a/botorch/optim/parameter_constraints.py b/botorch/optim/parameter_constraints.py
index 55fce9ffea..1a6829803f 100644
--- a/botorch/optim/parameter_constraints.py
+++ b/botorch/optim/parameter_constraints.py
@@ -312,13 +312,51 @@ def _make_linear_constraints(
     return constraints
 
 
+def _generate_unfixed_nonlin_constraints(
+    constraints: Optional[List[Callable[[Tensor], Tensor]]],
+    fixed_features: Dict[int, float],
+    dimension: int,
+) -> Optional[List[Callable[[Tensor], Tensor]]]:
+    """Given a dictionary of fixed features, returns a list of callables for
+    nonlinear inequality constraints expecting only a tensor with the non-fixed
+    features as input.
+
+    Args:
+        constraints: A list of callables representing constraints of the form
+            `callable(x) >= 0` over the full `dimension`-dim domain.
+        fixed_features: A map `{feature_index: value}` for the features that
+            should be fixed during candidate generation.
+        dimension: The dimension of the full (unreduced) domain.
+
+    Returns:
+        A list of callables expecting only the non-fixed features as input, or
+        `constraints` unchanged if it is None or empty.
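+
+    Example:
+        A minimal illustrative sketch (names and values chosen for
+        illustration): with feature 1 fixed to 2.0 in a 3-dim domain, the
+        wrapped constraint evaluates the original one at `[x_0, 2.0, x_1]`.
+
+        >>> def nlc(X):
+        ...     return 4 - X.sum(dim=-1)
+        >>> (new_nlc,) = _generate_unfixed_nonlin_constraints(
+        ...     constraints=[nlc], fixed_features={1: 2.0}, dimension=3
+        ... )
+        >>> new_nlc(torch.tensor([[4.0, 2.0]]))  # == nlc of [[4.0, 2.0, 2.0]]
+        tensor([-4.])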
+ """ + if not constraints: + return constraints + + selector = [] + idx_X, idx_f = 0, dimension - len(fixed_features) + for i in range(dimension): + if i in fixed_features.keys(): + selector.append(idx_f) + idx_f += 1 + else: + selector.append(idx_X) + idx_X += 1 + + values = torch.tensor(list(fixed_features.values()), dtype=torch.double) + + def _wrap_nonlin_constraint( + constraint: Callable[[Tensor], Tensor] + ) -> Callable[[Tensor], Tensor]: + def new_nonlin_constraint(X: Tensor) -> Tensor: + ivalues = values.to(X).expand(*X.shape[:-1], len(fixed_features)) + X_perm = torch.cat([X, ivalues], dim=-1) + return constraint(X_perm[..., selector]) + + return new_nonlin_constraint + + return [ + _wrap_nonlin_constraint(constraint=constraint) for constraint in constraints + ] + + def _generate_unfixed_lin_constraints( constraints: Optional[List[Tuple[Tensor, Tensor, float]]], fixed_features: Dict[int, float], dimension: int, eq: bool, ) -> Optional[List[Tuple[Tensor, Tensor, float]]]: - # If constraints is None or an empty list, then return itself if not constraints: return constraints diff --git a/test/generation/test_utils.py b/test/generation/test_utils.py index 209e76a566..f17d2dcbf0 100644 --- a/test/generation/test_utils.py +++ b/test/generation/test_utils.py @@ -86,16 +86,36 @@ def check_cons(old_cons, new_cons): else: self.assertEqual(old_cons, new_cons) + def check_nlc(old_nlcs, new_nlcs): + complete_data = torch.tensor( + [[4.0, 1.0, 2.0, -1.0, 3.0]], device=self.device + ) + reduced_data = torch.tensor([[4.0, 2.0, 3.0]], device=self.device) + if old_nlcs: + self.assertAllClose( + old_nlcs[0](complete_data), + new_nlcs[0](reduced_data), + ) + else: + self.assertEqual(old_nlcs, new_nlcs) + + def nlc(x): + return x[..., 2] + + old_nlcs = [nlc] + for ( lower_bounds, upper_bounds, inequality_constraints, equality_constraints, + nonlinear_inequality_constraints, ) in product( [None, -1.0, tensor_lower_bounds], [None, 1.0, tensor_upper_bounds], [None, old_inequality_constraints], [None, old_equality_constraints], + [None, old_nlcs], ): _no_ff = _remove_fixed_features_from_optimization( fixed_features=fixed_features, @@ -105,6 +125,7 @@ def check_cons(old_cons, new_cons): upper_bounds=upper_bounds, inequality_constraints=inequality_constraints, equality_constraints=equality_constraints, + nonlinear_inequality_constraints=nonlinear_inequality_constraints, ) self.assertIsInstance( _no_ff.acquisition_function, FixedFeatureAcquisitionFunction @@ -114,3 +135,7 @@ def check_cons(old_cons, new_cons): check_bounds_and_init(upper_bounds, _no_ff.upper_bounds) check_cons(inequality_constraints, _no_ff.inequality_constraints) check_cons(equality_constraints, _no_ff.equality_constraints) + check_nlc( + nonlinear_inequality_constraints, + _no_ff.nonlinear_inequality_constraints, + ) diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index f0b9ce5ddc..d29b63e05c 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -862,6 +862,32 @@ def nlc4(x): torch.allclose(acq_value, torch.tensor(2.45, **tkwargs), atol=1e-3) ) + with torch.random.fork_rng(): + torch.manual_seed(0) + batch_initial_conditions = torch.rand(num_restarts, 1, 3, **tkwargs) + batch_initial_conditions[..., 0] = 2 + + # test with fixed features + candidates, acq_value = optimize_acqf( + acq_function=mock_acq_function, + bounds=bounds, + q=1, + nonlinear_inequality_constraints=[nlc1, nlc2], + batch_initial_conditions=batch_initial_conditions, + num_restarts=num_restarts, + fixed_features={0: 2}, + ) 
+        self.assertEqual(candidates[0, 0], 2.0)
+        self.assertTrue(
+            torch.allclose(
+                torch.sort(candidates).values,
+                torch.tensor([[0, 2, 2]], **tkwargs),
+            )
+        )
+        self.assertTrue(
+            torch.allclose(acq_value, torch.tensor(2.8284, **tkwargs), atol=1e-3)
+        )
+
         # Test that an ic_generator object with the same API as
         # gen_batch_initial_conditions returns candidates of the
         # required shape.
@@ -879,22 +905,6 @@ def nlc4(x):
         )
         self.assertEqual(candidates.size(), torch.Size([1, 3]))
 
-        # Make sure fixed features aren't supported
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            "Fixed features are not supported when non-linear inequality "
-            "constraints are given.",
-        ):
-            optimize_acqf(
-                acq_function=mock_acq_function,
-                bounds=bounds,
-                q=1,
-                nonlinear_inequality_constraints=[nlc1, nlc2, nlc3, nlc4],
-                batch_initial_conditions=batch_initial_conditions,
-                num_restarts=num_restarts,
-                fixed_features={0: 0.1},
-            )
-
         # Constraints must be passed in as lists
         with self.assertRaisesRegex(
             ValueError,
diff --git a/test/optim/test_parameter_constraints.py b/test/optim/test_parameter_constraints.py
index 748ee950cd..ee88cc7508 100644
--- a/test/optim/test_parameter_constraints.py
+++ b/test/optim/test_parameter_constraints.py
@@ -12,6 +12,7 @@
 from botorch.optim.parameter_constraints import (
     _arrayify,
     _generate_unfixed_lin_constraints,
+    _generate_unfixed_nonlin_constraints,
     _make_linear_constraints,
     eval_lin_constraint,
     lin_constraint_jac,
@@ -215,6 +216,42 @@ def test_make_scipy_linear_constraints_unsupported(self):
                 equality_constraints=[(indices, coefficients, 1.0)],
             )
 
+    def test_generate_unfixed_nonlin_constraints(self):
+        def nlc1(x):
+            return 4 - x.sum(dim=-1)
+
+        def nlc2(x):
+            return x[..., 0] - 1
+
+        # first test with one constraint
+        (new_nlc1,) = _generate_unfixed_nonlin_constraints(
+            constraints=[nlc1], fixed_features={1: 2.0}, dimension=3
+        )
+        self.assertAllClose(
+            nlc1(torch.tensor([[4.0, 2.0, 2.0]], device=self.device)),
+            new_nlc1(torch.tensor([[4.0, 2.0]], device=self.device)),
+        )
+        # test with several constraints
+        constraints = [nlc1, nlc2]
+        new_constraints = _generate_unfixed_nonlin_constraints(
+            constraints=constraints, fixed_features={1: 2.0}, dimension=3
+        )
+        for nlc, new_nlc in zip(constraints, new_constraints):
+            self.assertAllClose(
+                nlc(torch.tensor([[4.0, 2.0, 2.0]], device=self.device)),
+                new_nlc(torch.tensor([[4.0, 2.0]], device=self.device)),
+            )
+        # test with several constraints and two fixed features
+        constraints = [nlc1, nlc2]
+        new_constraints = _generate_unfixed_nonlin_constraints(
+            constraints=constraints, fixed_features={1: 2.0, 2: 1.0}, dimension=3
+        )
+        for nlc, new_nlc in zip(constraints, new_constraints):
+            self.assertAllClose(
+                nlc(torch.tensor([[4.0, 2.0, 1.0]], device=self.device)),
+                new_nlc(torch.tensor([[4.0]], device=self.device)),
+            )
+
     def test_generate_unfixed_lin_constraints(self):
         # Case 1: some fixed features are in the indices
         indices = [