Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support base_margin for xgboost #309

Merged
merged 13 commits into from
Sep 6, 2024
11 changes: 11 additions & 0 deletions R/LearnerRegrXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
alpha = p_dbl(0, default = 0, tags = "train"),
approxcontrib = p_lgl(default = FALSE, tags = "predict"),
base_score = p_dbl(default = 0.5, tags = "train"),
base_margin = p_uty(default = NULL, special_vals = list(NULL), tags = "train"),
booster = p_fct(c("gbtree", "gblinear", "dart"), default = "gbtree", tags = "train"),
callbacks = p_uty(default = list(), tags = "train"),
colsample_bylevel = p_dbl(0, 1, default = 1, tags = "train"),
Expand Down Expand Up @@ -206,6 +207,13 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
xgboost::setinfo(data, "weight", task$weights$weight)
}

bm = pv$base_margin
pv$base_margin = NULL # silence xgb.train message
bm_is_feature = !is.null(bm) && is.character(bm) && (bm %in% task$feature_names)
if (bm_is_feature) {
xgboost::setinfo(data, "base_margin", task$data(cols = bm)[[1L]])
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
}

# the last element in the watchlist is used as the early stopping set
internal_valid_task = task$internal_valid_task
if (!is.null(pv$early_stopping_rounds) && is.null(internal_valid_task)) {
Expand All @@ -215,6 +223,9 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
test_data = internal_valid_task$data(cols = task$feature_names)
test_target = internal_valid_task$data(cols = task$target_names)
test_data = xgboost::xgb.DMatrix(data = as_numeric_matrix(test_data), label = data.matrix(test_target))
if (bm_is_feature) {
xgboost::setinfo(test_data, "base_margin", internal_valid_task$data(cols = bm)[[1L]])
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
}
pv$watchlist = c(pv$watchlist, list(test = test_data))
}

Expand Down
3 changes: 2 additions & 1 deletion inst/paramtest/test_paramtest_regr.xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ test_that("regr.xgboost", {
"label", # handled by mlr3
"weight", # handled by mlr3
"nthread", # handled by mlr3
"feval" # handled via eval_metric parameter
"feval", # handled via eval_metric parameter
"base_margin" # handled by mlr3
)

ParamTest = run_paramtest(learner, fun, exclude, tag = "train")
Expand Down