move sparseBO to OSS (#1915)
Summary:
Pull Request resolved: #1915

X-link: facebookexternal/botorch_fb#12

X-link: facebook/Ax#1676

Move SEBO-L0/L1 to OSS
- Move homotopy from botorch/fb to botorch

Reviewed By: esantorella

Differential Revision: D46528626

fbshipit-source-id: 904ddf38eb3b93fe6fd9d1ae8bde520239849052
Qing Feng authored and facebook-github-bot committed Jul 5, 2023
1 parent 7eb847a commit 8fcbd6b
Showing 6 changed files with 654 additions and 1 deletion.
39 changes: 38 additions & 1 deletion botorch/acquisition/penalized.py
@@ -11,7 +11,7 @@
from __future__ import annotations

import math
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
@@ -139,6 +139,43 @@ def forward(self, X: Tensor) -> Tensor:
        return regularization_term


def narrow_gaussian(X: Tensor, a: Tensor) -> Tensor:
    r"""A narrow Gaussian basis function centered at zero with width `a`."""
    return torch.exp(-0.5 * (X / a) ** 2)


def nnz_approx(X: Tensor, target_point: Tensor, a: Tensor) -> Tensor:
    r"""Differentiable relaxation of ||X - target_point||_0

    Args:
        X: An `n x d` tensor of inputs.
        target_point: A tensor of size `d` corresponding to the target point.
        a: A scalar tensor that controls the differentiable relaxation.
    """
    d = X.shape[-1]
    if d != target_point.shape[-1]:
        raise ValueError("X and target_point have different shapes.")
    return d - narrow_gaussian(X - target_point, a).sum(dim=-1, keepdim=True)
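
As `a` shrinks, each Gaussian term tends to 1 where a coordinate matches the target and to 0 where it differs, so `nnz_approx` approaches the exact L0 distance. A minimal sanity check — a sketch, assuming the `botorch.acquisition.penalized` import path this diff adds the function to:

import torch

from botorch.acquisition.penalized import nnz_approx

X = torch.tensor([[0.0, 0.5, 2.0]])  # one point with d=3; two coords differ from 0
target = torch.zeros(3)

# The relaxation tightens towards the exact count (2.0) as `a` shrinks.
for a in (1.0, 0.1, 0.01):
    val = nnz_approx(X=X, target_point=target, a=torch.tensor(a))
    print(f"a={a}: {val.item():.4f}")  # 0.9822, 2.0000, 2.0000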


class L0Approximation(torch.nn.Module):
    r"""Differentiable relaxation of the L0 norm using a Gaussian basis function."""

    def __init__(self, target_point: Tensor, a: float = 1.0, **tkwargs: Any) -> None:
        r"""Initialize the L0 penalty with a differentiable relaxation.

        Args:
            target_point: A tensor corresponding to the target point.
            a: A hyperparameter that controls the differentiable relaxation.
            **tkwargs: Tensor dtype and device keyword arguments.
        """
        super().__init__()
        self.target_point = target_point
        # Hyperparameter controlling the differentiable relaxation of the L0 norm;
        # registered as a buffer so it travels with the module.
        self.register_buffer("a", torch.tensor(a, **tkwargs))

    def __call__(self, X: Tensor) -> Tensor:
        return nnz_approx(X=X, target_point=self.target_point, a=self.a)
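
Since the relaxation is differentiable, it can serve as a penalty term inside a gradient-based acquisition optimizer (e.g. via `PenalizedAcquisitionFunction` below). A short illustrative sketch; the target point and width here are arbitrary:

import torch

from botorch.acquisition.penalized import L0Approximation

l0 = L0Approximation(target_point=torch.zeros(3), a=0.05)

X = torch.tensor([[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]], requires_grad=True)
penalty = l0(X)  # `n x 1`: approximately how many coords differ from the target
penalty.sum().backward()  # differentiable, unlike the true L0 norm
print(penalty.squeeze(-1))  # -> approx. tensor([1., 3.])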


class PenalizedAcquisitionFunction(AcquisitionFunction):
    r"""Single-outcome acquisition function regularized by the given penalty.
14 changes: 14 additions & 0 deletions botorch/optim/__init__.py
@@ -15,6 +15,13 @@
    scipy_minimize,
    torch_minimize,
)
from botorch.optim.homotopy import (
    FixedHomotopySchedule,
    Homotopy,
    HomotopyParameter,
    LinearHomotopySchedule,
    LogLinearHomotopySchedule,
)
from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg
from botorch.optim.numpy_converter import module_to_array, set_params_with_array
from botorch.optim.optimize import (
@@ -25,6 +32,7 @@
    optimize_acqf_discrete_local_search,
    optimize_acqf_mixed,
)
from botorch.optim.optimize_homotopy import optimize_acqf_homotopy
from botorch.optim.stopping import ExpMAStoppingCriterion


@@ -42,9 +50,15 @@
    "optimize_acqf_discrete",
    "optimize_acqf_discrete_local_search",
    "optimize_acqf_mixed",
    "optimize_acqf_homotopy",
    "module_to_array",
    "scipy_minimize",
    "set_params_with_array",
    "torch_minimize",
    "ExpMAStoppingCriterion",
    "FixedHomotopySchedule",
    "Homotopy",
    "HomotopyParameter",
    "LinearHomotopySchedule",
    "LogLinearHomotopySchedule",
]
191 changes: 191 additions & 0 deletions botorch/optim/homotopy.py
@@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Parameter


class HomotopySchedule(ABC):
    @property
    @abstractmethod
    def num_steps(self) -> int:
        """Number of steps in the schedule."""
        pass

    @property
    @abstractmethod
    def value(self) -> Any:
        """Current value in the schedule."""
        pass

    @property
    @abstractmethod
    def should_stop(self) -> bool:
        """Return true if we have incremented past the end of the schedule."""
        pass

    @abstractmethod
    def restart(self) -> None:
        """Restart the schedule to start from the beginning."""
        pass

    @abstractmethod
    def step(self) -> None:
        """Move to solving the next problem."""
        pass


class FixedHomotopySchedule(HomotopySchedule):
    """Homotopy schedule with a fixed list of values."""

    def __init__(self, values: List[Any]) -> None:
        r"""Initialize FixedHomotopySchedule.

        Args:
            values: A list of values used in the homotopy schedule.
        """
        self._values = values
        self.idx = 0

    @property
    def num_steps(self) -> int:
        return len(self._values)

    @property
    def value(self) -> Any:
        return self._values[self.idx]

    @property
    def should_stop(self) -> bool:
        return self.idx == len(self._values)

    def restart(self) -> None:
        self.idx = 0

    def step(self) -> None:
        self.idx += 1


class LinearHomotopySchedule(FixedHomotopySchedule):
    """Linear homotopy schedule."""

    def __init__(self, start: float, end: float, num_steps: int) -> None:
        r"""Initialize LinearHomotopySchedule.

        Args:
            start: start value of homotopy
            end: end value of homotopy
            num_steps: number of steps in the homotopy schedule.
        """
        super().__init__(
            values=torch.linspace(start, end, num_steps, dtype=torch.double).tolist()
        )


class LogLinearHomotopySchedule(FixedHomotopySchedule):
    """Log-linear homotopy schedule."""

    def __init__(self, start: float, end: float, num_steps: int):
        r"""Initialize LogLinearHomotopySchedule.

        Args:
            start: start value of homotopy
            end: end value of homotopy
            num_steps: number of steps in the homotopy schedule.
        """
        super().__init__(
            values=torch.logspace(
                math.log10(start), math.log10(end), num_steps, dtype=torch.double
            ).tolist()
        )
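
Both schedules precompute a fixed list of values, so stepping is cheap and deterministic. The log-linear variant spends more of its steps at small values, which suits annealing a relaxation width such as `L0Approximation.a`. A small sketch of the difference (values rounded):

from botorch.optim.homotopy import (
    LinearHomotopySchedule,
    LogLinearHomotopySchedule,
)

linear = LinearHomotopySchedule(start=0.2, end=0.01, num_steps=5)
loglinear = LogLinearHomotopySchedule(start=0.2, end=0.01, num_steps=5)

while not linear.should_stop:
    # linear:     0.2000, 0.1525, 0.1050, 0.0575, 0.0100
    # log-linear: 0.2000, 0.0946, 0.0447, 0.0211, 0.0100
    print(linear.value, loglinear.value)
    linear.step()
    loglinear.step()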


@dataclass
class HomotopyParameter:
    r"""Homotopy parameter.

    The parameter is expected to be either a torch Parameter or a torch Tensor,
    which may correspond to a buffer of a module. The parameter has a
    corresponding schedule.
    """

    parameter: Union[Parameter, Tensor]
    schedule: HomotopySchedule


class Homotopy:
    """Generic homotopy class.

    This class is designed to be used in `optimize_acqf_homotopy`. Given a set of
    homotopy parameters and corresponding schedules, we step through the homotopies
    until we have solved the final problem. We additionally support passing in a
    list of callbacks that will be executed each time `step`, `reset`, and
    `restart` are called.
    """

    def __init__(
        self,
        homotopy_parameters: List[HomotopyParameter],
        callbacks: Optional[List[Callable]] = None,
    ) -> None:
        r"""Initialize the homotopy.

        Args:
            homotopy_parameters: List of homotopy parameters.
            callbacks: Optional list of callbacks that are executed each time
                `restart`, `reset`, or `step` are called. These may be used to,
                e.g., reinitialize the acquisition function, which is needed when
                using qNEHVI.
        """
        # TODO: Check inputs
        self._homotopy_parameters = homotopy_parameters
        self._callbacks = callbacks or []
        self._original_values = [
            hp.parameter.item() for hp in self._homotopy_parameters
        ]
        assert all(
            isinstance(hp.parameter, (Parameter, Tensor))
            for hp in self._homotopy_parameters
        )
        # Assume the same number of steps for now
        assert len({h.schedule.num_steps for h in self._homotopy_parameters}) == 1
        # Initialize the homotopy parameters
        self.restart()

    def _execute_callbacks(self) -> None:
        """Execute the callbacks."""
        for callback in self._callbacks:
            callback()

    @property
    def should_stop(self) -> bool:
        """Returns true if all schedules have reached the end."""
        return all(h.schedule.should_stop for h in self._homotopy_parameters)

    def restart(self) -> None:
        """Restart the homotopy to use the initial value in the schedule."""
        for hp in self._homotopy_parameters:
            hp.schedule.restart()
            hp.parameter.data.fill_(hp.schedule.value)
        self._execute_callbacks()

    def reset(self) -> None:
        """Reset the homotopy parameters to their original values."""
        for hp, val in zip(self._homotopy_parameters, self._original_values):
            hp.parameter.data.fill_(val)
        self._execute_callbacks()

    def step(self) -> None:
        """Take a step according to the schedules."""
        for hp in self._homotopy_parameters:
            hp.schedule.step()
            if not hp.schedule.should_stop:
                hp.parameter.data.fill_(hp.schedule.value)
        self._execute_callbacks()
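
Per the class docstring, `optimize_acqf_homotopy` is the intended consumer of this class, but the loop is simple enough to sketch standalone. The tensor below is an illustrative stand-in for a real homotopy parameter such as the `a` buffer of an `L0Approximation`:

import torch

from botorch.optim.homotopy import (
    Homotopy,
    HomotopyParameter,
    LogLinearHomotopySchedule,
)

a = torch.tensor(0.2)  # illustrative; e.g. `l0_approximation.a` in practice
homotopy = Homotopy(
    homotopy_parameters=[
        HomotopyParameter(
            parameter=a,
            schedule=LogLinearHomotopySchedule(start=0.2, end=0.01, num_steps=4),
        )
    ]
)

while not homotopy.should_stop:
    # ... re-optimize the acquisition function at the current value of `a` ...
    print(a.item())  # 0.2, 0.0737, 0.0271, 0.01
    homotopy.step()
homotopy.reset()  # restore `a` to its original value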