Skip to content

Commit

Permalink
Add ER/IR-L0 acquisition function
Browse files Browse the repository at this point in the history
Summary: Implement internal regularization with L0 norm (IR-L0) and external regularization with L0 norm (ER-L0) in MBM.

Differential Revision: D46059263

fbshipit-source-id: 621f4d9aa0b593db524cff6765e2b2e3a6848a9a
  • Loading branch information
Qing Feng authored and facebook-github-bot committed Jul 5, 2023
1 parent f69a8be commit b306f29
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 0 deletions.
58 changes: 58 additions & 0 deletions botorch/acquisition/penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,29 @@ def __call__(self, X: Tensor) -> Tensor:
return nnz_approx(X=X, target_point=self.target_point, a=self.a)


class L0PenaltyApprox(L0Approximation):
    r"""Smooth (differentiable) surrogate of the L0 norm used as a penalty.

    Meant to be combined with an arbitrary acquisition function in order to
    build a PenalizedAcquisitionFunction.
    """

    def __init__(self, target_point: Tensor, a: float = 1.0, **tkwargs: Any) -> None:
        r"""Set up the relaxed L0 penalty.

        Args:
            target_point: A tensor corresponding to the target point.
            a: A hyperparameter that controls the differentiable relaxation.
            **tkwargs: Additional keyword arguments, forwarded unchanged to the
                parent constructor.
        """
        super().__init__(target_point=target_point, a=a, **tkwargs)

    def __call__(self, X: Tensor) -> Tensor:
        r"""Evaluate the penalty, reduced to a single value per q-batch.

        Args:
            X: A "batch_shape x q x dim" tensor of points to be evaluated.

        Returns:
            A tensor of size "batch_shape": for each q-batch, the smallest
            per-point value produced by the parent class.
        """
        # Drop the trailing singleton dim of the parent's output so that the
        # last remaining dim indexes the q points.
        per_point = super().__call__(X=X).squeeze(dim=-1)
        # `min` along a dim returns (values, indices); only the values are used.
        smallest, _ = per_point.min(dim=-1)
        return smallest


class PenalizedAcquisitionFunction(AcquisitionFunction):
r"""Single-outcome acquisition function regularized by the given penalty.
Expand Down Expand Up @@ -297,6 +320,7 @@ def __init__(
objective: Callable[[Tensor, Optional[Tensor]], Tensor],
penalty_objective: torch.nn.Module,
regularization_parameter: float,
expand_dim: Optional[int] = None,
) -> None:
r"""Penalized MC objective.
Expand All @@ -309,10 +333,13 @@ def __init__(
`batch-shape x q x d`-dim Tensor `X` and outputs a
`1 x batch-shape x q`-dim Tensor of penalty objective values.
regularization_parameter: weight of the penalty (regularization) term
expand_dim: dim to expand penalty_objective to match with objective when fully
bayesian model is used. If None, no expansion is performed.
"""
super().__init__(objective=objective)
self.penalty_objective = penalty_objective
self.regularization_parameter = regularization_parameter
self.expand_dim = expand_dim

def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
r"""Evaluate the penalized objective on the samples.
Expand All @@ -329,4 +356,35 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
"""
obj = super().forward(samples=samples, X=X)
penalty_obj = self.penalty_objective(X)
# when a fully Bayesian model is used, `expand_dim` is set so that penalty_obj
# (`1 x batch-shape x q`) can be broadcast against obj
# (`sample_shape x batch-shape x mcmc_samples x q`)
if self.expand_dim is not None:
# reshape penalty_obj to match the dim
penalty_obj = penalty_obj.unsqueeze(self.expand_dim)
return obj - self.regularization_parameter * penalty_obj


class L0PenaltyApproxObjective(L0Approximation):
    r"""Penalty objective based on a differentiable relaxation of the L0 norm.

    An instance of this class can be added to any arbitrary objective to
    construct a PenalizedMCObjective.
    """

    def __init__(self, target_point: Tensor, a: float = 1.0, **tkwargs: Any) -> None:
        r"""Set up the relaxed L0 penalty objective.

        Args:
            target_point: A tensor corresponding to the target point.
            a: A hyperparameter that controls the differentiable relaxation.
            **tkwargs: Additional keyword arguments, forwarded unchanged to the
                parent constructor.
        """
        super().__init__(target_point=target_point, a=a, **tkwargs)

    def __call__(self, X: Tensor) -> Tensor:
        r"""Evaluate the penalty for every point in `X`.

        Args:
            X: A "batch_shape x q x dim" tensor of points to be evaluated.

        Returns:
            A "1 x batch_shape x q" tensor representing the penalty for each
            point. The first dimension corresponds to the dimension of MC
            samples.
        """
        # Remove the trailing singleton dim of the parent's output ...
        per_point = super().__call__(X=X).squeeze(dim=-1)
        # ... and prepend a singleton dim standing in for the MC-sample dim.
        return per_point.unsqueeze(dim=0)
95 changes: 95 additions & 0 deletions test/acquisition/test_penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
group_lasso_regularizer,
GroupLassoPenalty,
L0Approximation,
L0PenaltyApprox,
L0PenaltyApproxObjective,
L1Penalty,
L1PenaltyObjective,
L2Penalty,
Expand Down Expand Up @@ -151,6 +153,58 @@ def test_L0Approximation(self):
rtol=1e-04,
)

def test_L0PenaltyApproxObjective(self):
    """Check construction and output shapes of L0PenaltyApproxObjective."""
    batch_shape = 16
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}
        target_point = torch.zeros(2, **tkwargs)

        # construction keeps the target point and the default `a` of 1.0
        l0_obj = L0PenaltyApproxObjective(target_point=target_point, **tkwargs)
        self.assertTrue(torch.equal(l0_obj.target_point, target_point))
        self.assertAllClose(l0_obj.a.data, torch.tensor(1.0, **tkwargs))

        # (input X, expected output) pairs: a two-dim `q x dim` input and a
        # "batch_shape x q x dim" input; both are expected to give all zeros
        # with a leading singleton dim.
        cases = (
            (
                torch.zeros(3, 2, **tkwargs),
                torch.zeros(1, 3, **tkwargs),
            ),
            (
                torch.zeros(batch_shape, 3, 2, **tkwargs),
                torch.zeros(1, batch_shape, 3, **tkwargs),
            ),
        )
        for X, expected in cases:
            self.assertTrue(torch.equal(l0_obj(X).data, expected))

def test_L0PenaltyApprox(self):
    """Check construction and output shapes of L0PenaltyApprox."""
    batch_shape = 16
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": self.device, "dtype": dtype}
        target_point = torch.zeros(2, **tkwargs)

        # construction keeps the target point and the default `a` of 1.0
        l0_acqf = L0PenaltyApprox(target_point=target_point, **tkwargs)
        self.assertTrue(torch.equal(l0_acqf.target_point, target_point))
        self.assertAllClose(l0_acqf.a.data, torch.tensor(1.0, **tkwargs))

        # (input X, expected output) pairs: a two-dim `q x dim` input is
        # expected to give a scalar zero, and a "batch_shape x q x dim" input
        # a zero vector of length `batch_shape`.
        cases = (
            (
                torch.zeros(3, 2, **tkwargs),
                torch.tensor(0, **tkwargs),
            ),
            (
                torch.zeros(batch_shape, 3, 2, **tkwargs),
                torch.zeros(batch_shape, **tkwargs),
            ),
        )
        for X, expected in cases:
            self.assertTrue(torch.equal(l0_acqf(X).data, expected))


class TestPenalizedAcquisitionFunction(BotorchTestCase):
def test_penalized_acquisition_function(self):
Expand Down Expand Up @@ -231,6 +285,8 @@ def test_penalized_mc_objective(self):
penalty_objective=l1_penalty_obj,
regularization_parameter=0.1,
)
# test self.expand_dim
self.assertIsNone(obj.expand_dim)
# test 'd' Tensor X
samples = torch.randn(4, 3, device=self.device, dtype=dtype)
X = torch.randn(4, 5, device=self.device, dtype=dtype)
Expand All @@ -246,3 +302,42 @@ def test_penalized_mc_objective(self):
X = torch.randn(3, 2, 5, device=self.device, dtype=dtype)
penalized_obj = generic_obj(samples) - 0.1 * l1_penalty_obj(X)
self.assertTrue(torch.equal(obj(samples, X), penalized_obj))

# test passing expand_dim
expand_dim = -2
obj2 = PenalizedMCObjective(
objective=generic_obj,
penalty_objective=l1_penalty_obj,
regularization_parameter=0.1,
expand_dim=expand_dim,
)
self.assertEqual(obj2.expand_dim, -2)
# test 'd' Tensor X
mcmc_samples = 8
# MCMC_dim = -3
samples = torch.randn(mcmc_samples, 4, 3, device=self.device, dtype=dtype)
X = torch.randn(4, 5, device=self.device, dtype=dtype)
penalized_obj = generic_obj(samples) - 0.1 * l1_penalty_obj(X).unsqueeze(
expand_dim
)
self.assertTrue(torch.equal(obj2(samples, X), penalized_obj))
# test 'q x d' Tensor X
# MCMC_dim = -3
samples = torch.randn(
4, mcmc_samples, 2, 3, device=self.device, dtype=dtype
)
X = torch.randn(2, 5, device=self.device, dtype=dtype)
penalized_obj = generic_obj(samples) - 0.1 * l1_penalty_obj(X).unsqueeze(
expand_dim
)
self.assertTrue(torch.equal(obj2(samples, X), penalized_obj))
# test 'batch-shape x q x d' Tensor X
# MCMC_dim = -3
samples = torch.randn(
4, 3, mcmc_samples, 2, 3, device=self.device, dtype=dtype
)
X = torch.randn(3, 2, 5, device=self.device, dtype=dtype)
penalized_obj = generic_obj(samples) - 0.1 * l1_penalty_obj(X).unsqueeze(
expand_dim
)
self.assertTrue(torch.equal(obj2(samples, X), penalized_obj))

0 comments on commit b306f29

Please sign in to comment.