Skip to content

Commit

Permalink
feat: allow EI to be adjusted by epsilon to strengthen exploration (#154
Browse files Browse the repository at this point in the history
)
  • Loading branch information
sumny authored Aug 13, 2024
1 parent b45f868 commit 2d99077
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 8 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: yes
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Collate:
'mlr_acqfunctions.R'
'AcqFunction.R'
Expand Down
4 changes: 2 additions & 2 deletions R/AcqFunctionCB.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ AcqFunctionCB = R6Class("AcqFunctionCB",
constants = list(...)
lambda = constants$lambda
p = self$surrogate$predict(xdt)
res = p$mean - self$surrogate_max_to_min * lambda * p$se
data.table(acq_cb = res)
cb = p$mean - self$surrogate_max_to_min * lambda * p$se
data.table(acq_cb = cb)
}
)
)
Expand Down
23 changes: 19 additions & 4 deletions R/AcqFunctionEI.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
#' @description
#' Expected Improvement.
#'
#' @section Parameters:
#' * `"epsilon"` (`numeric(1)`)\cr
#' \eqn{\epsilon} value used to determine the amount of exploration.
#' Higher values result in the importance of improvements predicted by the posterior mean
#' decreasing relative to the importance of potential improvements in regions of high predictive uncertainty.
#' Defaults to `0` (standard Expected Improvement).
#'
#' @references
#' * `r format_bib("jones_1998")`
#'
Expand Down Expand Up @@ -60,9 +67,15 @@ AcqFunctionEI = R6Class("AcqFunctionEI",
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param surrogate (`NULL` | [SurrogateLearner]).
initialize = function(surrogate = NULL) {
#' @param epsilon (`numeric(1)`).
initialize = function(surrogate = NULL, epsilon = 0) {
assert_r6(surrogate, "SurrogateLearner", null.ok = TRUE)
super$initialize("acq_ei", surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement", man = "mlr3mbo::mlr_acqfunctions_ei")
assert_number(epsilon, lower = 0, finite = TRUE)

constants = ps(epsilon = p_dbl(lower = 0, default = 0))
constants$values$epsilon = epsilon

super$initialize("acq_ei", constants = constants, surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement", man = "mlr3mbo::mlr_acqfunctions_ei")
},

#' @description
Expand All @@ -73,14 +86,16 @@ AcqFunctionEI = R6Class("AcqFunctionEI",
),

private = list(
.fun = function(xdt) {
.fun = function(xdt, ...) {
if (is.null(self$y_best)) {
stop("$y_best is not set. Missed to call $update()?")
}
constants = list(...)
epsilon = constants$epsilon
p = self$surrogate$predict(xdt)
mu = p$mean
se = p$se
d = self$y_best - self$surrogate_max_to_min * mu
d = (self$y_best - self$surrogate_max_to_min * mu) - epsilon
d_norm = d / se
ei = d * pnorm(d_norm) + se * dnorm(d_norm)
ei = ifelse(se < 1e-20, 0, ei)
Expand Down
15 changes: 14 additions & 1 deletion man/mlr_acqfunctions_ei.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions tests/testthat/test_AcqFunctionCB.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ test_that("AcqFunctionCB works", {
expect_learner(acqf$surrogate$learner)
expect_true(acqf$requires_predict_type_se)

expect_r6(acqf$constants, "ParamSet")
expect_equal(acqf$constants$ids(), "lambda")

design = MAKE_DESIGN(inst)
inst$eval_batch(design)

Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_AcqFunctionEHVIGH.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ test_that("AcqFunctionEHVIGH works", {
expect_true(acqf$requires_predict_type_se)
expect_setequal(acqf$packages, c("emoa", "fastGHQuad"))

expect_r6(acqf$constants, "ParamSet")
expect_equal(acqf$constants$ids(), c("k", "r"))

design = MAKE_DESIGN(inst)
inst$eval_batch(design)

Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_AcqFunctionEI.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ test_that("AcqFunctionEI works", {
expect_learner(acqf$surrogate$learner)
expect_true(acqf$requires_predict_type_se)

expect_r6(acqf$constants, "ParamSet")
expect_equal(acqf$constants$ids(), "epsilon")

design = MAKE_DESIGN(inst)
inst$eval_batch(design)

Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_AcqFunctionSmsEgo.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ test_that("AcqFunctionSmsEgo works", {
expect_list(acqf$surrogate$learner, types = "Learner")
expect_true(acqf$requires_predict_type_se)

expect_r6(acqf$constants, "ParamSet")
expect_equal(acqf$constants$ids(), c("lambda", "epsilon"))

design = MAKE_DESIGN(inst)
inst$eval_batch(design)

Expand Down

0 comments on commit 2d99077

Please sign in to comment.