Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move sparseBO to OSS #1915

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion botorch/acquisition/penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from __future__ import annotations

import math
from typing import Callable, List, Optional
from typing import Any, Callable, List, Optional

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
Expand Down Expand Up @@ -139,6 +139,43 @@ def forward(self, X: Tensor) -> Tensor:
return regularization_term


def narrow_gaussian(X: Tensor, a: Tensor) -> Tensor:
return torch.exp(-0.5 * (X / a) ** 2)


def nnz_approx(X: Tensor, target_point: Tensor, a: Tensor) -> Tensor:
r"""Differentiable relaxation of ||X - target_point||_0
Args:
X: An `n x d` tensor of inputs.
target_point: A tensor of size `n` corresponding to the target point.
a: A scalar tensor that controls the differentiable relaxation.
"""
d = X.shape[-1]
if d != target_point.shape[-1]:
raise ValueError("X and target_point have different shapes.")
return d - narrow_gaussian(X - target_point, a).sum(dim=-1, keepdim=True)


class L0Approximation(torch.nn.Module):
r"""Differentiable relaxation of the L0 norm using a Gaussian basis function."""

def __init__(self, target_point: Tensor, a: float = 1.0, **tkwargs: Any) -> None:
r"""Initializing L0 penalty with differentiable relaxation.
Args:
target_point: A tensor corresponding to the target point.
a: A hyperparameter that controls the differentiable relaxation.
"""
super().__init__()
self.target_point = target_point
# hyperparameter to control the differentiable relaxation in L0 norm function.
self.register_buffer("a", torch.tensor(a, **tkwargs))

def __call__(self, X: Tensor) -> Tensor:
return nnz_approx(X=X, target_point=self.target_point, a=self.a)


class PenalizedAcquisitionFunction(AcquisitionFunction):
r"""Single-outcome acquisition function regularized by the given penalty.
Expand Down
14 changes: 14 additions & 0 deletions botorch/optim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
scipy_minimize,
torch_minimize,
)
from botorch.optim.homotopy import (
FixedHomotopySchedule,
Homotopy,
HomotopyParameter,
LinearHomotopySchedule,
LogLinearHomotopySchedule,
)
from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg
from botorch.optim.numpy_converter import module_to_array, set_params_with_array
from botorch.optim.optimize import (
Expand All @@ -25,6 +32,7 @@
optimize_acqf_discrete_local_search,
optimize_acqf_mixed,
)
from botorch.optim.optimize_homotopy import optimize_acqf_homotopy
from botorch.optim.stopping import ExpMAStoppingCriterion


Expand All @@ -42,9 +50,15 @@
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
"optimize_acqf_mixed",
"optimize_acqf_homotopy",
"module_to_array",
"scipy_minimize",
"set_params_with_array",
"torch_minimize",
"ExpMAStoppingCriterion",
"FixedHomotopySchedule",
"Homotopy",
"HomotopyParameter",
"LinearHomotopySchedule",
"LogLinearHomotopySchedule",
]
191 changes: 191 additions & 0 deletions botorch/optim/homotopy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Parameter


class HomotopySchedule(ABC):
@property
@abstractmethod
def num_steps(self) -> int:
"""Number of steps in the schedule."""
pass

@property
@abstractmethod
def value(self) -> Any:
"""Current value in the schedule."""
pass

@property
@abstractmethod
def should_stop(self) -> bool:
"""Return true if we have incremented past the end of the schedule."""
pass

@abstractmethod
def restart(self) -> None:
"""Restart the schedule to start from the beginning."""
pass

@abstractmethod
def step(self) -> None:
"""Move to solving the next problem."""
pass


class FixedHomotopySchedule(HomotopySchedule):
"""Homotopy schedule with a fixed list of values."""

def __init__(self, values: List[Any]) -> None:
r"""Initialize FixedHomotopySchedule.
Args:
values: A list of values used in homotopy
"""
self._values = values
self.idx = 0

@property
def num_steps(self) -> int:
return len(self._values)

@property
def value(self) -> Any:
return self._values[self.idx]

@property
def should_stop(self) -> bool:
return self.idx == len(self._values)

def restart(self) -> None:
self.idx = 0

def step(self) -> None:
self.idx += 1


class LinearHomotopySchedule(FixedHomotopySchedule):
"""Linear homotopy schedule."""

def __init__(self, start: float, end: float, num_steps: int) -> None:
r"""Initialize LinearHomotopySchedule.
Args:
start: start value of homotopy
end: end value of homotopy
num_steps: number of steps in the homotopy schedule.
"""
super().__init__(
values=torch.linspace(start, end, num_steps, dtype=torch.double).tolist()
)


class LogLinearHomotopySchedule(FixedHomotopySchedule):
"""Log-linear homotopy schedule."""

def __init__(self, start: float, end: float, num_steps: int):
r"""Initialize LogLinearHomotopySchedule.
Args:
start: start value of homotopy
end: end value of homotopy
num_steps: number of steps in the homotopy schedule.
"""
super().__init__(
values=torch.logspace(
math.log10(start), math.log10(end), num_steps, dtype=torch.double
).tolist()
)


@dataclass
class HomotopyParameter:
r"""Homotopy parameter.
The parameter is expected to either be a torch parameter or a torch tensor which may
correspond to a buffer of a module. The parameter has a corresponding schedule.
"""
parameter: Union[Parameter, Tensor]
schedule: HomotopySchedule


class Homotopy:
"""Generic homotopy class.
This class is designed to be used in `optimize_acqf_homotopy`. Given a set of
homotopy parameters and corresponding schedules we step through the homotopies
until we have solved the final problem. We additionally support passing in a list
of callbacks that will be executed each time `step`, `reset`, and `restart` are
called.
"""

def __init__(
self,
homotopy_parameters: List[HomotopyParameter],
callbacks: Optional[List[Callable]] = None,
) -> None:
r"""Initialize the homotopy.
Args:
homotopy_parameters: List of homotopy parameters
callbacks: Optional list of callbacks that are executed each time
`restart`, `reset`, or `step` are called. These may be used to, e.g.,
reinitialize the acquisition function which is needed when using qNEHVI.
"""
# TODO: Check inputs
self._homotopy_parameters = homotopy_parameters
self._callbacks = callbacks or []
self._original_values = [
hp.parameter.item() for hp in self._homotopy_parameters
]
assert all(
isinstance(hp.parameter, Parameter) or isinstance(hp.parameter, Tensor)
for hp in self._homotopy_parameters
)
# Assume the same number of steps for now
assert len({h.schedule.num_steps for h in self._homotopy_parameters}) == 1
# Initialize the homotopy parameters
self.restart()

def _execute_callbacks(self) -> None:
"""Execute the callbacks."""
for callback in self._callbacks:
callback()

@property
def should_stop(self) -> bool:
"""Returns true if all schedules have reached the end."""
return all(h.schedule.should_stop for h in self._homotopy_parameters)

def restart(self) -> None:
"""Restart the homotopy to use the initial value in the schedule."""
for hp in self._homotopy_parameters:
hp.schedule.restart()
hp.parameter.data.fill_(hp.schedule.value)
self._execute_callbacks()

def reset(self) -> None:
"""Reset the homotopy parameter to their original values."""
for hp, val in zip(self._homotopy_parameters, self._original_values):
hp.parameter.data.fill_(val)
self._execute_callbacks()

def step(self) -> None:
"""Take a step according to the schedules."""
for hp in self._homotopy_parameters:
hp.schedule.step()
if not hp.schedule.should_stop:
hp.parameter.data.fill_(hp.schedule.value)
self._execute_callbacks()
Loading
Loading