diff --git a/botorch/acquisition/penalized.py b/botorch/acquisition/penalized.py index 6b929a8c96..b5529a4e3e 100644 --- a/botorch/acquisition/penalized.py +++ b/botorch/acquisition/penalized.py @@ -176,6 +176,29 @@ def __call__(self, X: Tensor) -> Tensor: return nnz_approx(X=X, target_point=self.target_point, a=self.a) +class L0PenaltyApprox(L0Approximation): + r"""Differentiable relaxation of the L0 norm to be added to any arbitrary acquisition + function to construct a PenalizedAcquisitionFunction.""" + + def __init__(self, target_point: Tensor, a: float = 1.0, **tkwargs: Any) -> None: + r"""Initializing L0 penalty with differentiable relaxation. + + Args: + target_point: A tensor corresponding to the target point. + a: A hyperparameter that controls the differentiable relaxation. + """ + super().__init__(target_point=target_point, a=a, **tkwargs) + + def __call__(self, X: Tensor) -> Tensor: + r""" + Args: + X: A "batch_shape x q x dim" representing the points to be evaluated. + Returns: + A tensor of size "batch_shape" representing the acqfn for each q-batch. + """ + return super().__call__(X=X).squeeze(dim=-1).min(dim=-1).values + + class PenalizedAcquisitionFunction(AcquisitionFunction): r"""Single-outcome acquisition function regularized by the given penalty. @@ -297,6 +320,7 @@ def __init__( objective: Callable[[Tensor, Optional[Tensor]], Tensor], penalty_objective: torch.nn.Module, regularization_parameter: float, + expand_dim: Optional[int] = None, ) -> None: r"""Penalized MC objective. @@ -309,10 +333,13 @@ def __init__( `batch-shape x q x d`-dim Tensor `X` and outputs a `1 x batch-shape x q`-dim Tensor of penalty objective values. regularization_parameter: weight of the penalty (regularization) term + expand_dim: dim to expand penalty_objective to match with objective when fully + bayesian model is used. If None, no expansion is performed. """ super().__init__(objective=objective) self.penalty_objective = penalty_objective self.regularization_parameter = regularization_parameter + self.expand_dim = expand_dim def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor: r"""Evaluate the penalized objective on the samples. @@ -329,4 +356,35 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor: """ obj = super().forward(samples=samples, X=X) penalty_obj = self.penalty_objective(X) + # when fully bayesian model is used, we pass unmarginalize_dim to match the shape + # between obj `sample_shape x batch-shape x mcmc_samples x q` and penalty_obj `1 x batch-shape x q` + if self.expand_dim is not None: + # reshape penalty_obj to match the dim + penalty_obj = penalty_obj.unsqueeze(self.expand_dim) return obj - self.regularization_parameter * penalty_obj + + +class L0PenaltyApproxObjective(L0Approximation): + r"""Differentiable relaxation of the L0 norm penalty objective class. + An instance of this class can be added to any arbitrary objective to + construct a PenalizedMCObjective. + """ + + def __init__(self, target_point: Tensor, a: float = 1.0, **tkwargs: Any) -> None: + r"""Initializing L0 penalty with differentiable relaxation. + + Args: + target_point: A tensor corresponding to the target point. + a: A hyperparameter that controls the differentiable relaxation. + """ + super().__init__(target_point=target_point, a=a, **tkwargs) + + def __call__(self, X: Tensor) -> Tensor: + r""" + Args: + X: A "batch_shape x q x dim" representing the points to be evaluated. + Returns: + A "1 x batch_shape x q" tensor representing the penalty for each point. + The first dimension corresponds to the dimension of MC samples. + """ + return super().__call__(X=X).squeeze(dim=-1).unsqueeze(dim=0) diff --git a/test/acquisition/test_penalized.py b/test/acquisition/test_penalized.py index 67578ef975..ae451645fc 100644 --- a/test/acquisition/test_penalized.py +++ b/test/acquisition/test_penalized.py @@ -12,6 +12,8 @@ group_lasso_regularizer, GroupLassoPenalty, L0Approximation, + L0PenaltyApprox, + L0PenaltyApproxObjective, L1Penalty, L1PenaltyObjective, L2Penalty, @@ -151,6 +153,58 @@ def test_L0Approximation(self): rtol=1e-04, ) + def test_L0PenaltyApproxObjective(self): + for dtype in (torch.float, torch.double): + tkwargs = {"device": self.device, "dtype": dtype} + target_point = torch.zeros(2, **tkwargs) + + # test init + l0_obj = L0PenaltyApproxObjective(target_point=target_point, **tkwargs) + self.assertTrue(torch.equal(l0_obj.target_point, target_point)) + self.assertAllClose(l0_obj.a.data, torch.tensor(1.0, **tkwargs)) + + # check two-dim input tensors X + self.assertTrue( + torch.equal( + l0_obj(torch.zeros(3, 2, **tkwargs)).data, + torch.zeros(1, 3, **tkwargs), + ) + ) + # check "batch_shape x q x dim" input tensors X + batch_shape = 16 + self.assertTrue( + torch.equal( + l0_obj(torch.zeros(batch_shape, 3, 2, **tkwargs)).data, + torch.zeros(1, batch_shape, 3, **tkwargs), + ) + ) + + def test_L0PenaltyApprox(self): + for dtype in (torch.float, torch.double): + tkwargs = {"device": self.device, "dtype": dtype} + target_point = torch.zeros(2, **tkwargs) + + # test init + l0_acqf = L0PenaltyApprox(target_point=target_point, **tkwargs) + self.assertTrue(torch.equal(l0_acqf.target_point, target_point)) + self.assertAllClose(l0_acqf.a.data, torch.tensor(1.0, **tkwargs)) + + # check two-dim input tensors X + self.assertTrue( + torch.equal( + l0_acqf(torch.zeros(3, 2, **tkwargs)).data, + torch.tensor(0, **tkwargs), + ) + ) + # check "batch_shape x q x dim" input tensors X + batch_shape = 16 + self.assertTrue( + torch.equal( + l0_acqf(torch.zeros(batch_shape, 3, 2, **tkwargs)).data, + torch.zeros(batch_shape, **tkwargs), + ) + ) + class TestPenalizedAcquisitionFunction(BotorchTestCase): def test_penalized_acquisition_function(self): @@ -231,6 +285,8 @@ def test_penalized_mc_objective(self): penalty_objective=l1_penalty_obj, regularization_parameter=0.1, ) + # test self.expand_dim + self.assertIsNone(obj.expand_dim) # test 'd' Tensor X samples = torch.randn(4, 3, device=self.device, dtype=dtype) X = torch.randn(4, 5, device=self.device, dtype=dtype) @@ -246,3 +302,42 @@ def test_penalized_mc_objective(self): X = torch.randn(3, 2, 5, device=self.device, dtype=dtype) penalized_obj = generic_obj(samples) - 0.1 * l1_penalty_obj(X) self.assertTrue(torch.equal(obj(samples, X), penalized_obj)) + + # test passing expand_dim + expand_dim = -2 + obj2 = PenalizedMCObjective( + objective=generic_obj, + penalty_objective=l1_penalty_obj, + regularization_parameter=0.1, + expand_dim=expand_dim, + ) + self.assertEqual(obj2.expand_dim, -2) + # test 'd' Tensor X + mcmc_samples = 8 + # MCMC_dim = -3 + samples = torch.randn(mcmc_samples, 4, 3, device=self.device, dtype=dtype) + X = torch.randn(4, 5, device=self.device, dtype=dtype) + penalized_obj = generic_obj(samples) - 0.1 * l1_penalty_obj(X).unsqueeze( + expand_dim + ) + self.assertTrue(torch.equal(obj2(samples, X), penalized_obj)) + # test 'q x d' Tensor X + # MCMC_dim = -3 + samples = torch.randn( + 4, mcmc_samples, 2, 3, device=self.device, dtype=dtype + ) + X = torch.randn(2, 5, device=self.device, dtype=dtype) + penalized_obj = generic_obj(samples) - 0.1 * l1_penalty_obj(X).unsqueeze( + expand_dim + ) + self.assertTrue(torch.equal(obj2(samples, X), penalized_obj)) + # test 'batch-shape x q x d' Tensor X + # MCMC_dim = -3 + samples = torch.randn( + 4, 3, mcmc_samples, 2, 3, device=self.device, dtype=dtype + ) + X = torch.randn(3, 2, 5, device=self.device, dtype=dtype) + penalized_obj = generic_obj(samples) - 0.1 * l1_penalty_obj(X).unsqueeze( + expand_dim + ) + self.assertTrue(torch.equal(obj2(samples, X), penalized_obj))