Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add learner, resampling and measure weights #1124

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Depends:
R (>= 3.1.0)
Imports:
R6 (>= 2.4.1),
backports,
backports (>= 1.5.0),
checkmate (>= 2.0.0),
data.table (>= 1.15.0),
evaluate,
Expand Down
17 changes: 11 additions & 6 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
# mlr3 (development version)

* Deprecated `data_format` and `data_formats` for Learners, Tasks, and DataBackends.
* refactor: It is now possible to use weights also during scoring predictions via measures and during resampling to sample observations with unequal probability.
The weights must be stored in the task and can be assigned the column role `weights_measure` or `weights_resampling`, respectively.
The weights used during training by the Learner are renamed to `weights_learner`, the previous column role `weight` is dysfunctional.
Additionally, it is now possible to disable the use of weights via the new hyperparameter `use_weights`.
Note that this is a breaking change, but appears to be the less error-prone solution in the long run.
* refactor: Deprecated `data_format` and `data_formats` for Learners, Tasks, and DataBackends.
* feat: The `partition()` function creates training, test and validation sets.
* refactor: Optimize runtime of fixing factor levels.
* refactor: Optimize runtime of setting row roles.
* refactor: Optimize runtime of marshalling.
* refactor: Optimize runtime of `Task$col_info`.
* fix: column info is now checked for compatibility during `Learner$predict` (#943).
* BREAKING CHANGE: the predict time of the learner now stores the cumulative duration for all predict sets (#992).
* fix: Column info is now checked for compatibility during `Learner$predict` (#943).
* BREAKING CHANGE: The predict time of the learner now stores the cumulative duration for all predict sets (#992).
* feat: `$internal_valid_task` can now be set to an `integer` vector.
* feat: Measures can now have an empty `$predict_sets` (#1094).
this is relevant for measures that only extract information from
the model of a learner (such as internal validation scores or AIC / BIC)
This is relevant for measures that only extract information from the model of a learner (such as internal validation scores or AIC / BIC)
* refactor: Deprecated the `$divide()` method
* fix: `Task$cbind()` now works with non-standard primary keys for `data.frames` (#961).
* fix: Triggering of fallback learner now has log-level `"info"` instead of `"debug"` (#972).
* feat: Added new measure `pinballs `.
* feat: Added new measure `mu_auc`.
* feat: Add option to calculate the mean of the true values on the train set in `msr("regr.rsq")`.
* feat: default fallback learner is set when encapsulation is activated.
* feat: Default fallback learner is set when encapsulation is activated.
* feat: Learners classif.debug and regr.debug have new methods `$importance()` and `$selected_features()` for testing, also in downstream packages

# mlr3 0.20.2

Expand Down
54 changes: 52 additions & 2 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@
#' Only available for [`Learner`]s with the `"internal_tuning"` property.
#' If the learner is not trained yet, this returns `NULL`.
#'
#' @section Weights:
#'
#' Many learners support observation weights, indicated by their property `"weights"`.
#' The weights are stored in the [Task] where the column role `weights_learner` needs to be assigned to a single numeric column.
#' The weights are automatically used if found in the task, this can be disabled by setting the hyperparamerter `use_weights` to `FALSE`.
#' If the learner is set-up to use weights but the task does not have a designated weight column, an unweighted version is calculated instead.
#' The weights do not necessarily need to sum up to 1, they are passed down to the learner.
#'
#' @section Setting Hyperparameters:
#'
#' All information about hyperparameters is stored in the slot `param_set` which is a [paradox::ParamSet].
Expand Down Expand Up @@ -215,7 +223,6 @@ Learner = R6Class("Learner",
self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.param_set = assert_param_set(param_set)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
Expand All @@ -225,6 +232,13 @@ Learner = R6Class("Learner",
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$man = assert_string(man, na.ok = TRUE)

if ("weights" %in% self$properties) {
self$use_weights = "use"
} else {
self$use_weights = "error"
}
private$.param_set = param_set

check_packages_installed(packages, msg = sprintf("Package '%%s' required but not installed for Learner '%s'", id))
},

Expand Down Expand Up @@ -405,7 +419,7 @@ Learner = R6Class("Learner",
assert_names(newdata$colnames, must.include = task$feature_names)

# the following columns are automatically set to NA if missing
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weight")], use.names = FALSE)
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weights_learner", "weights_measure", "weights_resampling")], use.names = FALSE)
impute = setdiff(impute, newdata$colnames)
if (length(impute)) {
# create list with correct NA types and cbind it to the backend
Expand Down Expand Up @@ -453,6 +467,25 @@ Learner = R6Class("Learner",
),

active = list(
#' @field use_weights (`character(1)`)\cr
#' How to use weights.
#' Settings are `"use"` `"ignore"`, and `"error"`.
#'
#' * `"use"`: use weights, as supported by the underlying `Learner`.
#' * `"ignore"`: do not use weights.
#' * `"error"`: throw an error if weights are present in the training `Task`.
#'
#' For `Learner`s with the property `"weights_learner"`, this is initialized as `"use"`.
#' For `Learner`s that do not support weights, i.e. without the `"weights_learner"` property, this is initialized as `"error"`.
#' This behaviour is to avoid cases where a user erroneously assumes that a `Learner` supports weights when it does not.
use_weights = function(rhs) {
if (!missing(rhs)) {
assert_choice(rhs, c(if ("weights" %in% self$properties) "use", "ignore", "error"))
private$.use_weights = rhs
}
private$.use_weights
},

#' @field data_formats (`character()`)\cr
#' Supported data format. Always `"data.table"`..
#' This is deprecated and will be removed in the future.
Expand Down Expand Up @@ -613,12 +646,29 @@ Learner = R6Class("Learner",
),

private = list(
.use_weights = NULL,
.encapsulate = NULL,
.fallback = NULL,
.predict_type = NULL,
.param_set = NULL,
.hotstart_stack = NULL,

# retrieve weights from a task, if it has weights and if the user did not
# deactivate weight usage through `self$use_weights`.
# - `task`: Task to retrieve weights from
# - `no_weights_val`: Value to return if no weights are found (default NULL)
# return: Numeric vector of weights or `no_weights_val` (default NULL)
.get_weights = function(task, no_weights_val = NULL) {
if ("weights" %nin% self$properties) {
stop("private$.get_weights should not be used in Learners that do not have the 'weights_learner' property.")
}
if (self$use_weights == "use" && "weights_learner" %in% task$properties) {
task$weights_learner$weight
} else {
no_weights_val
}
},

deep_clone = function(name, value) {
switch(name,
.param_set = value$clone(deep = TRUE),
Expand Down
32 changes: 30 additions & 2 deletions R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,27 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
#' Additional arguments passed to [`unmarshal_model()`].
unmarshal = function(...) {
learner_unmarshal(.learner = self, ...)
},

#' @description
#' Returns 0 for each feature seen in training.
#' @return Named `numeric()`.
importance = function() {
if (is.null(self$model)) {
stopf("No model stored")
}
fns = self$state$feature_names
set_names(rep(0, length(fns)), fns)
},

#' @description
#' Always returns character(0).
#' @return `character()`.
selected_features = function() {
if (is.null(self$model)) {
stopf("No model stored")
}
character(0)
}
),
active = list(
Expand Down Expand Up @@ -169,8 +190,15 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
stopf("Early stopping is only possible when a validation task is present.")
}

model = list(response = as.character(sample(task$truth(), 1L)), pid = Sys.getpid(), id = UUIDgenerate(),
random_number = sample(100000, 1), iter = if (isTRUE(pv$early_stopping)) sample(pv$iter %??% 1L, 1L) else pv$iter %??% 1L
model = list(
response = as.character(sample(task$truth(), 1L)),
pid = Sys.getpid(),
id = UUIDgenerate(),
random_number = sample(100000, 1),
iter = if (isTRUE(pv$early_stopping))
sample(pv$iter %??% 1L, 1L)
else
pv$iter %??% 1L
)

if (!is.null(valid_truth)) {
Expand Down
10 changes: 4 additions & 6 deletions R/LearnerClassifRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#'
#' @section Initial parameter values:
#' * Parameter `xval` is initialized to 0 in order to save some computation time.
#' * Parameter `use_weights` can be set to `FALSE` to ignore observation weights with column role `weights_learner` ,
#' if present.
#'
#' @section Custom mlr3 parameters:
#' * Parameter `model` has been renamed to `keep_model`.
Expand Down Expand Up @@ -35,9 +37,8 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
minsplit = p_int(1L, default = 20L, tags = "train"),
surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
xval = p_int(0L, default = 10L, init = 0L, tags = "train")
)
ps$values = list(xval = 0L)

super$initialize(
id = "classif.rpart",
Expand Down Expand Up @@ -77,10 +78,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

pv$weights = private$.get_weights(task)
invoke(rpart::rpart, formula = task$formula(), data = task$data(), .args = pv, .opts = allow_partial_matching)
},

Expand Down
21 changes: 21 additions & 0 deletions R/LearnerRegrDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,27 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
man = "mlr3::mlr_learners_regr.debug",
label = "Debug Learner for Regression"
)
},

#' @description
#' Returns 0 for each feature seen in training.
#' @return Named `numeric()`.
importance = function() {
if (is.null(self$model)) {
stopf("No model stored")
}
fns = self$state$feature_names
set_names(rep(0, length(fns)), fns)
},

#' @description
#' Always returns character(0).
#' @return `character()`.
selected_features = function() {
if (is.null(self$model)) {
stopf("No model stored")
}
character(0)
}
),
private = list(
Expand Down
10 changes: 4 additions & 6 deletions R/LearnerRegrRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#'
#' @section Initial parameter values:
#' * Parameter `xval` is initialized to 0 in order to save some computation time.
#' * Parameter `use_weights` can be set to `FALSE` to ignore observation weights with column role `weights_learner` ,
#' if present.
#'
#' @section Custom mlr3 parameters:
#' * Parameter `model` has been renamed to `keep_model`.
Expand Down Expand Up @@ -35,9 +37,8 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
minsplit = p_int(1L, default = 20L, tags = "train"),
surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
xval = p_int(0L, default = 10L, init = 0L, tags = "train")
)
ps$values = list(xval = 0L)

super$initialize(
id = "regr.rpart",
Expand Down Expand Up @@ -77,10 +78,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

pv$weights = private$.get_weights(task)
invoke(rpart::rpart, formula = task$formula(), data = task$data(), .args = pv, .opts = allow_partial_matching)
},

Expand Down
41 changes: 29 additions & 12 deletions R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
#' In such cases it is necessary to overwrite the public methods `$aggregate()` and/or `$score()` to return a named `numeric()`
#' where at least one of its names corresponds to the `id` of the measure itself.
#'
#' @section Weights:
#'
#' Many measures support observation weights, indicated by their property `"weights"`.
#' The weights are stored in the [Task] where the column role `weights_measure` needs to be assigned to a single numeric column.
#' The weights are automatically used if found in the task, this can be disabled by setting the hyperparamerter `use_weights` to `FALSE`.
#' If the measure is set-up to use weights but the task does not have a designated weight column, an unweighted version is calculated instead.
#' The weights do not necessarily need to sum up to 1, they are normalized by dividing by the sum of weights.
#'
#' @template param_id
#' @template param_param_set
#' @template param_range
Expand Down Expand Up @@ -94,10 +102,6 @@ Measure = R6Class("Measure",
#' Lower and upper bound of possible performance scores.
range = NULL,

#' @field properties (`character()`)\cr
#' Properties of this measure.
properties = NULL,

#' @field minimize (`logical(1)`)\cr
#' If `TRUE`, good predictions correspond to small values of performance scores.
minimize = NULL,
Expand All @@ -117,7 +121,6 @@ Measure = R6Class("Measure",
predict_sets = "test", task_properties = character(), packages = character(),
label = NA_character_, man = NA_character_, trafo = NULL) {

self$properties = unique(properties)
self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = task_type
Expand All @@ -140,6 +143,8 @@ Measure = R6Class("Measure",
assert_subset(task_properties, mlr_reflections$task_properties[[task_type]])
}


self$properties = unique(properties)
self$predict_type = predict_type
self$predict_sets = predict_sets
self$task_properties = task_properties
Expand Down Expand Up @@ -195,24 +200,25 @@ Measure = R6Class("Measure",
#' @return `numeric(1)`.
score = function(prediction, task = NULL, learner = NULL, train_set = NULL) {
assert_measure(self, task = task, learner = learner)
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties)
properties = self$properties
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% properties)

if ("requires_task" %in% self$properties && is.null(task)) {
if ("requires_task" %in% properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}

if ("requires_learner" %in% self$properties && is.null(learner)) {
if ("requires_learner" %in% properties && is.null(learner)) {
stopf("Measure '%s' requires a learner", self$id)
}

if ("requires_model" %in% self$properties && (is.null(learner) || is.null(learner$model))) {
if ("requires_model" %in% properties && (is.null(learner) || is.null(learner$model))) {
stopf("Measure '%s' requires the trained model", self$id)
}
if ("requires_model" %in% self$properties && is_marshaled_model(learner$model)) {
if ("requires_model" %in% properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", self$id)
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
if ("requires_train_set" %in% properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
}

Expand All @@ -231,7 +237,6 @@ Measure = R6Class("Measure",
#'
#' @return `numeric(1)`.
aggregate = function(rr) {

switch(self$average,
"macro" = {
aggregator = self$aggregator %??% mean
Expand Down Expand Up @@ -275,6 +280,17 @@ Measure = R6Class("Measure",
self$predict_sets, mget(private$.extra_hash, envir = self))
},

#' @field properties (`character()`)\cr
#' Properties of this measure.
properties = function(rhs) {
if (!missing(rhs)) {
props = if (is.na(self$task_type)) unique(unlist(mlr_reflections$measure_properties), use.names = FALSE) else mlr_reflections$measure_properties[[self$task_type]]
private$.properties = assert_subset(rhs, props)
} else {
private$.properties
}
},

#' @field average (`character(1)`)\cr
#' Method for aggregation:
#'
Expand Down Expand Up @@ -307,6 +323,7 @@ Measure = R6Class("Measure",
),

private = list(
.properties = character(),
.predict_sets = NULL,
.extra_hash = character(),
.average = NULL,
Expand Down
Loading
Loading