-
Notifications
You must be signed in to change notification settings - Fork 393
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #1915 X-link: facebookexternal/botorch_fb#12 X-link: facebook/Ax#1676 Move SEBO-L0/L1 to OSS - Move homotopy from botorch/fb to botorch Reviewed By: esantorella Differential Revision: D46528626 fbshipit-source-id: 904ddf38eb3b93fe6fd9d1ae8bde520239849052
- Loading branch information
1 parent
7eb847a
commit 8fcbd6b
Showing
6 changed files
with
654 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from __future__ import annotations | ||
|
||
import math | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, List, Optional, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
from torch.nn import Parameter | ||
|
||
|
||
class HomotopySchedule(ABC): | ||
@property | ||
@abstractmethod | ||
def num_steps(self) -> int: | ||
"""Number of steps in the schedule.""" | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def value(self) -> Any: | ||
"""Current value in the schedule.""" | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def should_stop(self) -> bool: | ||
"""Return true if we have incremented past the end of the schedule.""" | ||
pass | ||
|
||
@abstractmethod | ||
def restart(self) -> None: | ||
"""Restart the schedule to start from the beginning.""" | ||
pass | ||
|
||
@abstractmethod | ||
def step(self) -> None: | ||
"""Move to solving the next problem.""" | ||
pass | ||
|
||
|
||
class FixedHomotopySchedule(HomotopySchedule): | ||
"""Homotopy schedule with a fixed list of values.""" | ||
|
||
def __init__(self, values: List[Any]) -> None: | ||
r"""Initialize FixedHomotopySchedule. | ||
Args: | ||
values: A list of values used in homotopy | ||
""" | ||
self._values = values | ||
self.idx = 0 | ||
|
||
@property | ||
def num_steps(self) -> int: | ||
return len(self._values) | ||
|
||
@property | ||
def value(self) -> Any: | ||
return self._values[self.idx] | ||
|
||
@property | ||
def should_stop(self) -> bool: | ||
return self.idx == len(self._values) | ||
|
||
def restart(self) -> None: | ||
self.idx = 0 | ||
|
||
def step(self) -> None: | ||
self.idx += 1 | ||
|
||
|
||
class LinearHomotopySchedule(FixedHomotopySchedule): | ||
"""Linear homotopy schedule.""" | ||
|
||
def __init__(self, start: float, end: float, num_steps: int) -> None: | ||
r"""Initialize LinearHomotopySchedule. | ||
Args: | ||
start: start value of homotopy | ||
end: end value of homotopy | ||
num_steps: number of steps in the homotopy schedule. | ||
""" | ||
super().__init__( | ||
values=torch.linspace(start, end, num_steps, dtype=torch.double).tolist() | ||
) | ||
|
||
|
||
class LogLinearHomotopySchedule(FixedHomotopySchedule): | ||
"""Log-linear homotopy schedule.""" | ||
|
||
def __init__(self, start: float, end: float, num_steps: int): | ||
r"""Initialize LogLinearHomotopySchedule. | ||
Args: | ||
start: start value of homotopy | ||
end: end value of homotopy | ||
num_steps: number of steps in the homotopy schedule. | ||
""" | ||
super().__init__( | ||
values=torch.logspace( | ||
math.log10(start), math.log10(end), num_steps, dtype=torch.double | ||
).tolist() | ||
) | ||
|
||
|
||
@dataclass | ||
class HomotopyParameter: | ||
r"""Homotopy parameter. | ||
The parameter is expected to either be a torch parameter or a torch tensor which may | ||
correspond to a buffer of a module. The parameter has a corresponding schedule. | ||
""" | ||
parameter: Union[Parameter, Tensor] | ||
schedule: HomotopySchedule | ||
|
||
|
||
class Homotopy: | ||
"""Generic homotopy class. | ||
This class is designed to be used in `optimize_acqf_homotopy`. Given a set of | ||
homotopy parameters and corresponding schedules we step through the homotopies | ||
until we have solved the final problem. We additionally support passing in a list | ||
of callbacks that will be executed each time `step`, `reset`, and `restart` are | ||
called. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
homotopy_parameters: List[HomotopyParameter], | ||
callbacks: Optional[List[Callable]] = None, | ||
) -> None: | ||
r"""Initialize the homotopy. | ||
Args: | ||
homotopy_parameters: List of homotopy parameters | ||
callbacks: Optional list of callbacks that are executed each time | ||
`restart`, `reset`, or `step` are called. These may be used to, e.g., | ||
reinitialize the acquisition function which is needed when using qNEHVI. | ||
""" | ||
# TODO: Check inputs | ||
self._homotopy_parameters = homotopy_parameters | ||
self._callbacks = callbacks or [] | ||
self._original_values = [ | ||
hp.parameter.item() for hp in self._homotopy_parameters | ||
] | ||
assert all( | ||
isinstance(hp.parameter, Parameter) or isinstance(hp.parameter, Tensor) | ||
for hp in self._homotopy_parameters | ||
) | ||
# Assume the same number of steps for now | ||
assert len({h.schedule.num_steps for h in self._homotopy_parameters}) == 1 | ||
# Initialize the homotopy parameters | ||
self.restart() | ||
|
||
def _execute_callbacks(self) -> None: | ||
"""Execute the callbacks.""" | ||
for callback in self._callbacks: | ||
callback() | ||
|
||
@property | ||
def should_stop(self) -> bool: | ||
"""Returns true if all schedules have reached the end.""" | ||
return all(h.schedule.should_stop for h in self._homotopy_parameters) | ||
|
||
def restart(self) -> None: | ||
"""Restart the homotopy to use the initial value in the schedule.""" | ||
for hp in self._homotopy_parameters: | ||
hp.schedule.restart() | ||
hp.parameter.data.fill_(hp.schedule.value) | ||
self._execute_callbacks() | ||
|
||
def reset(self) -> None: | ||
"""Reset the homotopy parameter to their original values.""" | ||
for hp, val in zip(self._homotopy_parameters, self._original_values): | ||
hp.parameter.data.fill_(val) | ||
self._execute_callbacks() | ||
|
||
def step(self) -> None: | ||
"""Take a step according to the schedules.""" | ||
for hp in self._homotopy_parameters: | ||
hp.schedule.step() | ||
if not hp.schedule.should_stop: | ||
hp.parameter.data.fill_(hp.schedule.value) | ||
self._execute_callbacks() |
Oops, something went wrong.