diff --git a/botorch/acquisition/penalized.py b/botorch/acquisition/penalized.py
index b114362ea9..6b929a8c96 100644
--- a/botorch/acquisition/penalized.py
+++ b/botorch/acquisition/penalized.py
@@ -11,7 +11,7 @@
 from __future__ import annotations
 
 import math
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
 from botorch.acquisition.acquisition import AcquisitionFunction
@@ -139,6 +139,43 @@ def forward(self, X: Tensor) -> Tensor:
         return regularization_term
 
 
+def narrow_gaussian(X: Tensor, a: Tensor) -> Tensor:
+    return torch.exp(-0.5 * (X / a) ** 2)
+
+
+def nnz_approx(X: Tensor, target_point: Tensor, a: Tensor) -> Tensor:
+    r"""Differentiable relaxation of ||X - target_point||_0
+
+    Args:
+        X: An `n x d` tensor of inputs.
+        target_point: A tensor of size `d` corresponding to the target point.
+        a: A scalar tensor that controls the differentiable relaxation.
+    """
+    d = X.shape[-1]
+    if d != target_point.shape[-1]:
+        raise ValueError("X and target_point have different shapes.")
+    return d - narrow_gaussian(X - target_point, a).sum(dim=-1, keepdim=True)
+
+
+class L0Approximation(torch.nn.Module):
+    r"""Differentiable relaxation of the L0 norm using a Gaussian basis function."""
+
+    def __init__(self, target_point: Tensor, a: float = 1.0, **tkwargs: Any) -> None:
+        r"""Initialize the L0 penalty with a differentiable relaxation.
+
+        Args:
+            target_point: A tensor corresponding to the target point.
+            a: A hyperparameter that controls the differentiable relaxation.
+        """
+        super().__init__()
+        self.target_point = target_point
+        # hyperparameter to control the differentiable relaxation in L0 norm function.
+        self.register_buffer("a", torch.tensor(a, **tkwargs))
+
+    def __call__(self, X: Tensor) -> Tensor:
+        return nnz_approx(X=X, target_point=self.target_point, a=self.a)
+
+
 class PenalizedAcquisitionFunction(AcquisitionFunction):
     r"""Single-outcome acquisition function regularized by the given penalty.
 
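To make the relaxation concrete: at the target point every coordinate contributes narrow_gaussian(0, a) = 1, so the relaxed norm is d - d = 0, while a coordinate far from the target contributes roughly 0 and thus adds 1 to the count, mirroring ||x - target_point||_0. A minimal behavior sketch (the tensors and values below are illustrative, not part of this diff):

    import torch
    from botorch.acquisition.penalized import L0Approximation

    target = torch.zeros(3, dtype=torch.double)
    l0 = L0Approximation(target_point=target, a=1.0, dtype=torch.double)
    X = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 5.0]], dtype=torch.double)
    # First row matches the target exactly (~0 active coordinates); the second
    # row differs in one coordinate, so the relaxed count is ~1.
    print(l0(X))  # ~ tensor([[0.0], [1.0]])
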
diff --git a/botorch/optim/__init__.py b/botorch/optim/__init__.py
index 540752d1e0..a9d9619469 100644
--- a/botorch/optim/__init__.py
+++ b/botorch/optim/__init__.py
@@ -15,6 +15,13 @@
     scipy_minimize,
     torch_minimize,
 )
+from botorch.optim.homotopy import (
+    FixedHomotopySchedule,
+    Homotopy,
+    HomotopyParameter,
+    LinearHomotopySchedule,
+    LogLinearHomotopySchedule,
+)
 from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg
 from botorch.optim.numpy_converter import module_to_array, set_params_with_array
 from botorch.optim.optimize import (
@@ -25,6 +32,7 @@
     optimize_acqf_discrete_local_search,
     optimize_acqf_mixed,
 )
+from botorch.optim.optimize_homotopy import optimize_acqf_homotopy
 from botorch.optim.stopping import ExpMAStoppingCriterion
 
 
@@ -42,9 +50,15 @@
     "optimize_acqf_discrete",
     "optimize_acqf_discrete_local_search",
     "optimize_acqf_mixed",
+    "optimize_acqf_homotopy",
     "module_to_array",
     "scipy_minimize",
     "set_params_with_array",
     "torch_minimize",
     "ExpMAStoppingCriterion",
+    "FixedHomotopySchedule",
+    "Homotopy",
+    "HomotopyParameter",
+    "LinearHomotopySchedule",
+    "LogLinearHomotopySchedule",
 ]
diff --git a/botorch/optim/homotopy.py b/botorch/optim/homotopy.py
new file mode 100644
index 0000000000..c6b5534f50
--- /dev/null
+++ b/botorch/optim/homotopy.py
@@ -0,0 +1,191 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import math
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Callable, List, Optional, Union
+
+import torch
+from torch import Tensor
+from torch.nn import Parameter
+
+
+class HomotopySchedule(ABC):
+    @property
+    @abstractmethod
+    def num_steps(self) -> int:
+        """Number of steps in the schedule."""
+        pass
+
+    @property
+    @abstractmethod
+    def value(self) -> Any:
+        """Current value in the schedule."""
+        pass
+
+    @property
+    @abstractmethod
+    def should_stop(self) -> bool:
+        """Return true if we have incremented past the end of the schedule."""
+        pass
+
+    @abstractmethod
+    def restart(self) -> None:
+        """Restart the schedule to start from the beginning."""
+        pass
+
+    @abstractmethod
+    def step(self) -> None:
+        """Move to solving the next problem."""
+        pass
+
+
+class FixedHomotopySchedule(HomotopySchedule):
+    """Homotopy schedule with a fixed list of values."""
+
+    def __init__(self, values: List[Any]) -> None:
+        r"""Initialize FixedHomotopySchedule.
+
+        Args:
+            values: A list of values used in the homotopy.
+        """
+        self._values = values
+        self.idx = 0
+
+    @property
+    def num_steps(self) -> int:
+        return len(self._values)
+
+    @property
+    def value(self) -> Any:
+        return self._values[self.idx]
+
+    @property
+    def should_stop(self) -> bool:
+        return self.idx == len(self._values)
+
+    def restart(self) -> None:
+        self.idx = 0
+
+    def step(self) -> None:
+        self.idx += 1
+
+
+class LinearHomotopySchedule(FixedHomotopySchedule):
+    """Linear homotopy schedule."""
+
+    def __init__(self, start: float, end: float, num_steps: int) -> None:
+        r"""Initialize LinearHomotopySchedule.
+
+        Args:
+            start: The start value of the homotopy.
+            end: The end value of the homotopy.
+            num_steps: The number of steps in the homotopy schedule.
+        """
+        super().__init__(
+            values=torch.linspace(start, end, num_steps, dtype=torch.double).tolist()
+        )
+
+
+class LogLinearHomotopySchedule(FixedHomotopySchedule):
+    """Log-linear homotopy schedule."""
+
+    def __init__(self, start: float, end: float, num_steps: int) -> None:
+        r"""Initialize LogLinearHomotopySchedule.
+
+        Args:
+            start: The start value of the homotopy.
+            end: The end value of the homotopy.
+            num_steps: The number of steps in the homotopy schedule.
+        """
+        super().__init__(
+            values=torch.logspace(
+                math.log10(start), math.log10(end), num_steps, dtype=torch.double
+            ).tolist()
+        )
+ """ + # TODO: Check inputs + self._homotopy_parameters = homotopy_parameters + self._callbacks = callbacks or [] + self._original_values = [ + hp.parameter.item() for hp in self._homotopy_parameters + ] + assert all( + isinstance(hp.parameter, Parameter) or isinstance(hp.parameter, Tensor) + for hp in self._homotopy_parameters + ) + # Assume the same number of steps for now + assert len({h.schedule.num_steps for h in self._homotopy_parameters}) == 1 + # Initialize the homotopy parameters + self.restart() + + def _execute_callbacks(self) -> None: + """Execute the callbacks.""" + for callback in self._callbacks: + callback() + + @property + def should_stop(self) -> bool: + """Returns true if all schedules have reached the end.""" + return all(h.schedule.should_stop for h in self._homotopy_parameters) + + def restart(self) -> None: + """Restart the homotopy to use the initial value in the schedule.""" + for hp in self._homotopy_parameters: + hp.schedule.restart() + hp.parameter.data.fill_(hp.schedule.value) + self._execute_callbacks() + + def reset(self) -> None: + """Reset the homotopy parameter to their original values.""" + for hp, val in zip(self._homotopy_parameters, self._original_values): + hp.parameter.data.fill_(val) + self._execute_callbacks() + + def step(self) -> None: + """Take a step according to the schedules.""" + for hp in self._homotopy_parameters: + hp.schedule.step() + if not hp.schedule.should_stop: + hp.parameter.data.fill_(hp.schedule.value) + self._execute_callbacks() diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py new file mode 100644 index 0000000000..e374d6f872 --- /dev/null +++ b/botorch/optim/optimize_homotopy.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Dict, Optional, Tuple, Union + +import torch +from botorch.acquisition import AcquisitionFunction +from botorch.optim.homotopy import Homotopy +from botorch.optim.optimize import optimize_acqf +from torch import Tensor + + +def prune_candidates( + candidates: Tensor, acq_values: Tensor, prune_tolerance: float +) -> Tensor: + r"""Prune candidates based on their distance to other candidates. + + Args: + candidates: An `n x d` tensor of candidates. + acq_values: An `n` tensor of candidate values. + prune_tolerance: The minimum distance to prune candidates. + + Returns: + An `m x d` tensor of pruned candidates. 
+ """ + if candidates.ndim != 2: + raise ValueError("`candidates` must be of size `n x d`.") + if acq_values.ndim != 1 or len(acq_values) != candidates.shape[0]: + raise ValueError("`acq_values` must be of size `n`.") + if prune_tolerance < 0: + raise ValueError("`prune_tolerance` must be >= 0.") + sorted_inds = acq_values.argsort(descending=True) + candidates = candidates[sorted_inds] + + candidates_new = candidates[:1, :] + for i in range(1, candidates.shape[0]): + if ( + torch.cdist(candidates[i : i + 1, :], candidates_new).min() + > prune_tolerance + ): + candidates_new = torch.cat( + [candidates_new, candidates[i : i + 1, :]], dim=-2 + ) + return candidates_new + + +def optimize_acqf_homotopy( + acq_function: AcquisitionFunction, + bounds: Tensor, + q: int, + homotopy: Homotopy, + num_restarts: int, + raw_samples: Optional[int] = None, + fixed_features: Optional[Dict[int, float]] = None, + options: Optional[Dict[str, Union[bool, float, int, str]]] = None, + final_options: Optional[Dict[str, Union[bool, float, int, str]]] = None, + batch_initial_conditions: Optional[Tensor] = None, + post_processing_func: Optional[Callable[[Tensor], Tensor]] = None, + prune_tolerance: float = 1e-4, +) -> Tuple[Tensor, Tensor]: + r"""Generate a set of candidates via multi-start optimization. + + TODO: Merge with `optimize_acqf` before moving to OSS. + + Args: + acq_function: An AcquisitionFunction. + bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`. + q: The number of candidates. + homotopy: Homotopy object that will make the necessary modifications to the + problem when calling `step()`. + num_restarts: The number of starting points for multistart acquisition + function optimization. + raw_samples: The number of samples for initialization. This is required + if `batch_initial_conditions` is not specified. + fixed_features: A map `{feature_index: value}` for features that + should be fixed to a particular value during generation. + options: Options for candidate generation. + final_options: Options for candidate generation in the last homotopy step. + batch_initial_conditions: A tensor to specify the initial conditions. Set + this if you do not want to use default initialization strategy. + post_processing_func: Post processing function (such as roundingor clamping) + that is applied before choosing the final candidate. + """ + candidate_list, acq_value_list = [], [] + if q > 1: + base_X_pending = acq_function.X_pending + + # TODO: Another option would be to have the homotopy in the outer loop and + # optimize over q in the inner loop. May be interesting to benchmark. + for _ in range(q): + # TODO: Do we want to generate new initial conditions after generating + # the first candidate? + candidates = batch_initial_conditions + homotopy.restart() + + # TODO: Do we want to allow using a decreasing number of initial conditions? + # It may be advantageous to start with a large number and prune the least + # promising candidates as we step through the homotopies. 
diff --git a/test/acquisition/test_penalized.py b/test/acquisition/test_penalized.py
index 818f51a089..67578ef975 100644
--- a/test/acquisition/test_penalized.py
+++ b/test/acquisition/test_penalized.py
@@ -11,6 +11,7 @@
     GaussianPenalty,
     group_lasso_regularizer,
     GroupLassoPenalty,
+    L0Approximation,
     L1Penalty,
     L1PenaltyObjective,
     L2Penalty,
@@ -104,6 +105,53 @@ def test_group_lasso_penalty(self):
             group_lasso_module(sample_point_2)
 
 
+class TestL0Approximation(BotorchTestCase):
+    def test_L0Approximation(self):
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            target_point = torch.zeros(2, **tkwargs)
+
+            # test init
+            l0 = L0Approximation(target_point=target_point, **tkwargs)
+            self.assertTrue(torch.equal(l0.target_point, target_point))
+            self.assertAllClose(l0.a.data, torch.tensor(1.0, **tkwargs))
+
+            # verify L0 norm
+            self.assertTrue(
+                torch.equal(
+                    l0(torch.zeros(2, **tkwargs)).data, torch.tensor([0], **tkwargs)
+                )
+            )
+            # check two-dim input tensors X
+            self.assertTrue(
+                torch.equal(
+                    l0(torch.zeros(3, 2, **tkwargs)).data, torch.zeros(3, 1, **tkwargs)
+                )
+            )
+
+            # test raise when X and target_point have mismatched shape
+            with self.assertRaises(ValueError):
+                l0(torch.zeros(3, **tkwargs))
+
+            # test init with different a
+            l0 = L0Approximation(target_point=target_point, a=2.0, **tkwargs)
+            self.assertAllClose(l0.a.data, torch.tensor(2.0, **tkwargs))
+            self.assertAllClose(
+                l0(torch.ones(2, **tkwargs)).data,
+                torch.tensor([0.2350], **tkwargs),
+                rtol=1e-04,
+            )
+
+            # reset a
+            l0.a.data.fill_(0.5)
+            self.assertTrue(torch.equal(l0.a.data, torch.tensor(0.5, **tkwargs)))
+            self.assertAllClose(
+                l0(torch.ones(2, **tkwargs)).data,
+                torch.tensor([1.7293], **tkwargs),
+                rtol=1e-04,
+            )
+
+
 class TestPenalizedAcquisitionFunction(BotorchTestCase):
     def test_penalized_acquisition_function(self):
         for dtype in (torch.float, torch.double):
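The constants in `TestL0Approximation` follow directly from the definition: with `d = 2` and `X - target_point = [1, 1]`, the relaxation evaluates to `d - d * exp(-0.5 / a**2)`. A quick check of the two expected values:

    import math

    print(2 - 2 * math.exp(-0.5 * (1 / 2.0) ** 2))  # a = 2.0 -> 0.23500...
    print(2 - 2 * math.exp(-0.5 * (1 / 0.5) ** 2))  # a = 0.5 -> 1.72932...
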
diff --git a/test/optim/test_homotopy.py b/test/optim/test_homotopy.py
new file mode 100644
index 0000000000..cebcb68ea3
--- /dev/null
+++ b/test/optim/test_homotopy.py
@@ -0,0 +1,208 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest.mock as mock
+
+import torch
+from botorch.acquisition import PosteriorMean
+from botorch.models import GenericDeterministicModel
+from botorch.optim.homotopy import (
+    FixedHomotopySchedule,
+    Homotopy,
+    HomotopyParameter,
+    LinearHomotopySchedule,
+    LogLinearHomotopySchedule,
+)
+from botorch.optim.optimize_homotopy import optimize_acqf_homotopy, prune_candidates
+from botorch.utils.testing import BotorchTestCase
+from torch.nn import Parameter
+
+
+PRUNE_CANDIDATES_PATH = f"{prune_candidates.__module__}"
+
+
+class TestHomotopy(BotorchTestCase):
+    def _test_schedule(self, schedule, values):
+        self.assertEqual(schedule.num_steps, len(values))
+        self.assertEqual(schedule.value, values[0])
+        self.assertFalse(schedule.should_stop)
+        for i in range(len(values) - 1):
+            schedule.step()
+            self.assertEqual(schedule.value, values[i + 1])
+            self.assertFalse(schedule.should_stop)
+        schedule.step()
+        self.assertTrue(schedule.should_stop)
+        schedule.restart()
+        self.assertEqual(schedule.value, values[0])
+        self.assertFalse(schedule.should_stop)
+
+    def test_fixed_schedule(self):
+        values = [1, 3, 7]
+        fixed = FixedHomotopySchedule(values=values)
+        self.assertEqual(fixed._values, values)
+        self._test_schedule(schedule=fixed, values=values)
+
+    def test_linear_schedule(self):
+        values = [1, 2, 3, 4, 5]
+        linear = LinearHomotopySchedule(start=1, end=5, num_steps=5)
+        self.assertEqual(linear._values, values)
+        self._test_schedule(schedule=linear, values=values)
+
+    def test_log_linear_schedule(self):
+        values = [0.01, 0.1, 1, 10, 100]
+        linear = LogLinearHomotopySchedule(start=0.01, end=100, num_steps=5)
+        self.assertEqual(linear._values, values)
+        self._test_schedule(schedule=linear, values=values)
+
+    def test_homotopy(self):
+        tkwargs = {"device": self.device, "dtype": torch.double}
+        p1 = Parameter(-2 * torch.ones(1, **tkwargs))
+        v1 = [1, 2, 3, 4, 5]
+        p2 = -3 * torch.ones(1, **tkwargs)
+        v2 = [0.01, 0.1, 1, 10, 100]
+        callback = mock.Mock()
+        homotopy_parameters = [
+            HomotopyParameter(
+                parameter=p1,
+                schedule=LinearHomotopySchedule(start=1, end=5, num_steps=5),
+            ),
+            HomotopyParameter(
+                parameter=p2,
+                schedule=LogLinearHomotopySchedule(start=0.01, end=100, num_steps=5),
+            ),
+        ]
+        homotopy = Homotopy(
+            homotopy_parameters=homotopy_parameters, callbacks=[callback]
+        )
+        self.assertEqual(homotopy._original_values, [-2, -3])
+        self.assertEqual(homotopy._homotopy_parameters, homotopy_parameters)
+        self.assertEqual(homotopy._callbacks, [callback])
+        self.assertEqual(
+            [h.parameter.item() for h in homotopy._homotopy_parameters], [v1[0], v2[0]]
+        )
+        for i in range(4):
+            homotopy.step()
+            self.assertEqual(
+                [h.parameter.item() for h in homotopy._homotopy_parameters],
+                [v1[i + 1], v2[i + 1]],
+            )
+            self.assertFalse(homotopy.should_stop)
+        homotopy.step()
+        self.assertTrue(homotopy.should_stop)
+        # Restart the schedules
+        homotopy.restart()
+        self.assertEqual(
+            [h.parameter.item() for h in homotopy._homotopy_parameters], [v1[0], v2[0]]
+        )
+        # Reset the parameters to their original values
+        homotopy.reset()
+        self.assertEqual(
+            [h.parameter.item() for h in homotopy._homotopy_parameters], [-2, -3]
+        )
+        # Expect the call count to be 8: init (1), step (5), restart (1), reset (1).
+        self.assertEqual(callback.call_count, 8)
+
+    def test_optimize_acqf_homotopy(self):
+        tkwargs = {"device": self.device, "dtype": torch.double}
+        p = Parameter(-2 * torch.ones(1, **tkwargs))
+        hp = HomotopyParameter(
+            parameter=p,
+            schedule=LinearHomotopySchedule(start=4, end=0, num_steps=5),
+        )
+        model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
+        acqf = PosteriorMean(model=model)
+        candidate, acqf_val = optimize_acqf_homotopy(
+            q=1,
+            acq_function=acqf,
+            bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
+            homotopy=Homotopy(homotopy_parameters=[hp]),
+            num_restarts=2,
+            raw_samples=16,
+            post_processing_func=lambda x: x.round(),
+        )
+        self.assertEqual(candidate, torch.zeros(1, **tkwargs))
+        self.assertEqual(acqf_val, 5 * torch.ones(1, **tkwargs))
+
+        # test fixed feature
+        fixed_features = {0: 1.0}
+        model = GenericDeterministicModel(
+            f=lambda x: 5 - (x - p).sum(dim=-1, keepdims=True) ** 2
+        )
+        acqf = PosteriorMean(model=model)
+        candidate, acqf_val = optimize_acqf_homotopy(
+            q=1,
+            acq_function=acqf,
+            bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
+            homotopy=Homotopy(homotopy_parameters=[hp]),
+            num_restarts=2,
+            raw_samples=16,
+            fixed_features=fixed_features,
+        )
+        self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))
+
+    def test_prune_candidates(self):
+        tkwargs = {"device": self.device, "dtype": torch.double}
+        # no pruning
+        X = torch.rand(6, 3, **tkwargs)
+        vals = X.sum(dim=-1)
+        X_pruned = prune_candidates(candidates=X, acq_values=vals, prune_tolerance=1e-6)
+        self.assertTrue((X[vals.argsort(descending=True), :] == X_pruned).all())
+        # pruning
+        X[1, :] = X[0, :] + 1e-10
+        X[4, :] = X[2, :] - 1e-10
+        vals = torch.tensor([1, 6, 3, 4, 2, 5], **tkwargs)
+        X_pruned = prune_candidates(candidates=X, acq_values=vals, prune_tolerance=1e-6)
+        self.assertTrue((X[[1, 5, 3, 2]] == X_pruned).all())
+        # invalid shapes
+        with self.assertRaisesRegex(
+            ValueError, "`candidates` must be of size `n x d`."
+        ):
+            prune_candidates(
+                candidates=torch.zeros(3, 2, 1),
+                acq_values=torch.zeros(2, 1),
+                prune_tolerance=1e-6,
+            )
+        with self.assertRaisesRegex(ValueError, "`acq_values` must be of size `n`."):
+            prune_candidates(
+                candidates=torch.zeros(3, 2),
+                acq_values=torch.zeros(3, 1),
+                prune_tolerance=1e-6,
+            )
+        with self.assertRaisesRegex(ValueError, "`prune_tolerance` must be >= 0."):
+            prune_candidates(
+                candidates=torch.zeros(3, 2),
+                acq_values=torch.zeros(3),
+                prune_tolerance=-1.2345,
+            )
+
+    @mock.patch(f"{PRUNE_CANDIDATES_PATH}.prune_candidates", wraps=prune_candidates)
+    def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
+        tkwargs = {"device": self.device, "dtype": torch.double}
+        p = Parameter(torch.zeros(1, **tkwargs))
+        hp = HomotopyParameter(
+            parameter=p,
+            schedule=LinearHomotopySchedule(start=4, end=0, num_steps=5),
+        )
+        model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
+        acqf = PosteriorMean(model=model)
+        candidate, acqf_val = optimize_acqf_homotopy(
+            q=1,
+            acq_function=acqf,
+            bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
+            homotopy=Homotopy(homotopy_parameters=[hp]),
+            num_restarts=4,
+            raw_samples=16,
+            post_processing_func=lambda x: x.round(),
+        )
+        # First time we expect to call `prune_candidates` with 4 candidates
+        self.assertEqual(
+            prune_candidates_mock.call_args_list[0][1]["candidates"].shape,
+            torch.Size([4, 1]),
+        )
+        for i in range(1, 5):  # The paths should have been pruned to just one path
+            self.assertEqual(
+                prune_candidates_mock.call_args_list[i][1]["candidates"].shape,
+                torch.Size([1, 1]),
+            )
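For reference, the pruning behavior these tests exercise, in isolation (illustrative numbers): candidates are sorted by acquisition value, and a candidate is kept only if it lies farther than `prune_tolerance` from every candidate kept so far.

    import torch
    from botorch.optim.optimize_homotopy import prune_candidates

    X = torch.tensor([[0.0], [1e-10], [1.0]], dtype=torch.double)
    vals = torch.tensor([1.0, 2.0, 3.0], dtype=torch.double)
    # Sorted by value: [1.0], [1e-10], [0.0]; the last point lies within the
    # tolerance of the second, so it is dropped.
    print(prune_candidates(candidates=X, acq_values=vals, prune_tolerance=1e-6))
    # tensor([[1.0000e+00], [1.0000e-10]], dtype=torch.float64)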