Skip to content

Commit

Permalink
Deprecate data_formats (#1067)
Browse files Browse the repository at this point in the history
* warn_deprecated

* deprecate data_format(s)

* adapt tests

* document

* keywords internal for warn_deprecated

* avoid unnecessary warnings in tests

* fix tests

* NEWS entry

---------

Co-authored-by: Michel Lang <[email protected]>
  • Loading branch information
mb706 and mllg authored Aug 21, 2024
1 parent 35392fb commit d72be13
Show file tree
Hide file tree
Showing 34 changed files with 290 additions and 237 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ RoxygenNote: 7.3.2
Collate:
'mlr_reflections.R'
'BenchmarkResult.R'
'warn_deprecated.R'
'DataBackend.R'
'DataBackendCbind.R'
'DataBackendDataTable.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ export(convert_task)
export(create_empty_prediction_data)
export(data.table)
export(default_measures)
export(deprecated_binding)
export(extract_pkgs)
export(filter_prediction_data)
export(install_pkgs)
Expand Down Expand Up @@ -244,6 +245,7 @@ export(tgens)
export(tsk)
export(tsks)
export(unmarshal_model)
export(warn_deprecated)
import(checkmate)
import(data.table)
import(mlr3misc)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# mlr3 (development version)

* Deprecated `data_format` and `data_formats` for Learners, Tasks, and DataBackends.
* feat: The `partition()` function creates training, test and validation sets.
* refactor: Optimize runtime of fixing factor levels.
* refactor: Optimize runtime of setting row roles.
Expand Down
14 changes: 8 additions & 6 deletions R/DataBackend.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#' @title DataBackend
#'
#' @include mlr_reflections.R
#' @include warn_deprecated.R
#'
#' @description
#' This is the abstract base class for data backends.
Expand Down Expand Up @@ -42,10 +43,6 @@ DataBackend = R6Class("DataBackend", cloneable = FALSE,
#' Column name of the primary key column of positive and unique integer row ids.
primary_key = NULL,

#' @field data_formats (`character()`)\cr
#' Set of supported formats, e.g. `"data.table"` or `"Matrix"`.
data_formats = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
Expand All @@ -62,10 +59,10 @@ DataBackend = R6Class("DataBackend", cloneable = FALSE,
#' Each DataBackend needs a way to address rows, which is done via a
#' column of unique integer values, referenced here by `primary_key`. The
#' use of this variable may differ between backends.
initialize = function(data, primary_key, data_formats = "data.table") {
initialize = function(data, primary_key, data_formats) {
private$.data = data
self$primary_key = assert_string(primary_key)
self$data_formats = assert_subset(data_formats, mlr_reflections$data_formats, empty.ok = FALSE)
if (!missing(data_formats)) warn_deprecated("DataBackend$initialize argument 'data_formats'")
},

#' @description
Expand All @@ -88,6 +85,11 @@ DataBackend = R6Class("DataBackend", cloneable = FALSE,
),

active = list(
#' @field data_formats (`character()`)\cr
#' Supported data format. Always `"data.table"`.
#' This is deprecated and will be removed in the future.
data_formats = deprecated_binding("DataBackend$data_formats", "data.table"),

#' @template field_hash
hash = function(rhs) {
if (missing(rhs)) {
Expand Down
15 changes: 5 additions & 10 deletions R/DataBackendCbind.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,23 @@ DataBackendCbind = R6Class("DataBackendCbind", inherit = DataBackend, cloneable
assert_backend(b2)
pk = b1$primary_key

data_formats = intersect(b1$data_formats, b2$data_formats)
if ("data.table" %nin% data_formats) {
stopf("There is supported data format for the backends to cbind (supported: 'data.table')")
}

if (pk != b2$primary_key) {
stopf("All backends to cbind must have the primary_key '%s'", pk)
}

super$initialize(list(b1 = b1, b2 = b2), pk, "data.table")
super$initialize(list(b1 = b1, b2 = b2), pk)
},

data = function(rows, cols, data_format = "data.table") {
data = function(rows, cols, data_format) {
pk = self$primary_key
qrows = unique(assert_numeric(rows))
qcols = union(assert_names(cols, type = "unique"), pk)
assert_choice(data_format, self$data_formats)
if (!missing(data_format)) warn_deprecated("DataBackendCbind$data argument 'data_format'")

data = private$.data$b2$data(qrows, qcols, data_format = data_format)
data = private$.data$b2$data(qrows, qcols)
if (ncol(data) < length(qcols)) {
qcols = c(setdiff(cols, names(data)), pk)
tmp = private$.data$b1$data(qrows, qcols, data_format = data_format)
tmp = private$.data$b1$data(qrows, qcols)
data = merge(data, tmp, by = pk, all = TRUE, sort = TRUE)
}

Expand Down
6 changes: 3 additions & 3 deletions R/DataBackendDataTable.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ DataBackendDataTable = R6Class("DataBackendDataTable", inherit = DataBackend,
#' The input [data.table()].
initialize = function(data, primary_key) {
assert_data_table(data, col.names = "unique")
super$initialize(setkeyv(data, primary_key), primary_key, data_formats = "data.table")
super$initialize(setkeyv(data, primary_key), primary_key)
ii = match(primary_key, names(data))
if (is.na(ii)) {
stopf("Primary key '%s' not in 'data'", primary_key)
Expand All @@ -60,10 +60,10 @@ DataBackendDataTable = R6Class("DataBackendDataTable", inherit = DataBackend,
#' Queries for rows with no matching row id and queries for columns with no matching column name are silently ignored.
#' Rows are guaranteed to be returned in the same order as `rows`, columns may be returned in an arbitrary order.
#' Duplicated row ids result in duplicated rows, duplicated column names lead to an exception.
data = function(rows, cols, data_format = "data.table") {
data = function(rows, cols, data_format) {
rows = assert_integerish(rows, coerce = TRUE)
assert_names(cols, type = "unique")
assert_choice(data_format, self$data_formats)
if (!missing(data_format)) warn_deprecated("DataBackendDataTable$data argument 'data_format'")
cols = intersect(cols, colnames(private$.data))

if (self$compact_seq) {
Expand Down
55 changes: 8 additions & 47 deletions R/DataBackendMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
#'
#' b = as_data_backend(data, dense = dense, primary_key = "..row_id")
#' b$head()
#' b$data(1:3, b$colnames, data_format = "Matrix")
#' b$data(1:3, b$colnames, data_format = "data.table")
#' b$data(1:3, b$colnames)
DataBackendMatrix = R6Class("DataBackendMatrix", inherit = DataBackend, cloneable = FALSE,
public = list(

Expand All @@ -52,20 +51,20 @@ DataBackendMatrix = R6Class("DataBackendMatrix", inherit = DataBackend, cloneabl

assert_disjunct(colnames(data), colnames(dense))

super$initialize(data = list(sparse = data, dense = as.data.table(dense)), primary_key, data_formats = c("Matrix", "data.table"))
super$initialize(data = list(sparse = data, dense = as.data.table(dense)), primary_key)
},

#' @description
#' Returns a slice of the data in the specified format.
#' Currently, the only supported formats are `"data.table"` and `"Matrix"`.
#' Returns a slice of the data as `"data.table"`.
#' The rows must be addressed as vector of primary key values, columns must be referred to via column names.
#' Queries for rows with no matching row id and queries for columns with no matching column name are silently ignored.
#' Rows are guaranteed to be returned in the same order as `rows`, columns may be returned in an arbitrary order.
#' Duplicated row ids result in duplicated rows, duplicated column names lead to an exception.
data = function(rows, cols, data_format = "data.table") {
data = function(rows, cols, data_format) {
assert_integerish(rows, coerce = TRUE)
assert_names(cols, type = "unique")
assert_choice(data_format, self$data_formats)

if (!missing(data_format)) warn_deprecated("DataBackendMatrix$data argument 'data_format'")

rows = private$.translate_rows(rows)
cols_sparse = intersect(cols, colnames(private$.data$sparse))
Expand All @@ -74,46 +73,8 @@ DataBackendMatrix = R6Class("DataBackendMatrix", inherit = DataBackend, cloneabl
sparse = private$.data$sparse[rows, cols_sparse, drop = FALSE]
dense = private$.data$dense[rows, cols_dense, with = FALSE]

if (data_format == "data.table") {
data = cbind(as.data.table(as.matrix(sparse)), dense)
setcolorder(data, intersect(cols, names(data)))
} else {
qassertr(dense, c("n", "f"))

factors = names(which(map_lgl(dense, is.factor)))
if (length(factors)) {
# create list of dummy matrices
dummies = imap(dense[, factors, with = FALSE], function(x, nn) {
if (nlevels(x) > 1L) {
contrasts = contr.treatment(levels(x), sparse = TRUE)
X = contrasts[match(x, rownames(contrasts), nomatch = 0L), , drop = FALSE]
colnames(X) = sprintf("%s_%s", nn, colnames(contrasts))
} else {
X = matrix(rep(1, nrow(dense)), ncol = 1L)
colnames(X) = sprintf("%s_%s", nn, levels(x))
}
X
})

replace_with = function(x, needle, replacement) {
ii = (x == needle)
x = rep(x, 1L + (length(replacement) - 1L) * ii)
replace(x, ii, replacement)
}

# update the column vector with new dummy names (this preserves the order)
cols = Reduce(function(cols, name) replace_with(cols, name, colnames(dummies[[name]])),
names(dummies), init = cols)
dense = remove_named(dense, factors)
} else {
dummies = NULL
}

dense = if (nrow(dense)) as.matrix(dense) else NULL
data = do.call(cbind, c(list(sparse, dense), dummies))
data[, match(cols, colnames(data), nomatch = 0L), drop = FALSE]
}

data = cbind(as.data.table(as.matrix(sparse)), dense)
setcolorder(data, intersect(cols, names(data)))
data
},

Expand Down
15 changes: 5 additions & 10 deletions R/DataBackendRbind.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,23 @@ DataBackendRbind = R6Class("DataBackendRbind", inherit = DataBackend, cloneable
assert_backend(b2)
pk = b1$primary_key

data_formats = intersect(b1$data_formats, b2$data_formats)
if ("data.table" %nin% data_formats) {
stopf("There is supported data format for the backends to cbind (supported: 'data.table')")
}

if (pk != b2$primary_key) {
stopf("All backends to rbind must have the primary_key '%s'", pk)
}

super$initialize(list(b1 = b1, b2 = b2), pk, "data.table")
super$initialize(list(b1 = b1, b2 = b2), pk)
},

data = function(rows, cols, data_format = "data.table") {
data = function(rows, cols, data_format) {
pk = self$primary_key
qrows = unique(assert_numeric(rows))
qcols = union(assert_names(cols, type = "unique"), pk)
assert_choice(data_format, self$data_formats)
if (!missing(data_format)) warn_deprecated("DataBackendRbind$data argument 'data_format'")

data = private$.data$b2$data(qrows, qcols, data_format = data_format)
data = private$.data$b2$data(qrows, qcols)
if (nrow(data) < length(qrows)) {
qrows = setdiff(rows, data[[pk]])
tmp = private$.data$b1$data(qrows, qcols, data_format = data_format)
tmp = private$.data$b1$data(qrows, qcols)
data = rbindlist(list(data, tmp), use.names = TRUE, fill = TRUE)
}

Expand Down
7 changes: 4 additions & 3 deletions R/DataBackendRename.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ DataBackendRename = R6Class("DataBackendRename", inherit = DataBackend, cloneabl
new = NULL,

initialize = function(b, old, new) {
super$initialize(data = b, b$primary_key, "data.table")
super$initialize(data = b, b$primary_key)
assert_character(old, any.missing = FALSE, unique = TRUE)
assert_subset(old, b$colnames)
assert_character(new, any.missing = FALSE, len = length(old))
Expand All @@ -30,11 +30,12 @@ DataBackendRename = R6Class("DataBackendRename", inherit = DataBackend, cloneabl
self$new = new
},

data = function(rows, cols, data_format = self$data_formats[1L]) {
data = function(rows, cols, data_format) {
assert_names(cols, type = "unique")
b = private$.data
cols = map_values(intersect(cols, self$colnames), self$new, self$old)
data = b$data(rows, cols, data_format)
if (!missing(data_format)) warn_deprecated("DataBackendRename$data argument 'data_format'")
data = b$data(rows, cols)
set_col_names(data, map_values(names(data), self$old, self$new))
},

Expand Down
15 changes: 9 additions & 6 deletions R/Learner.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#' @title Learner Class
#'
#' @include mlr_reflections.R
#' @include warn_deprecated.R
#'
#' @description
#' This is the abstract base class for learner objects like [LearnerClassif] and [LearnerRegr].
Expand Down Expand Up @@ -174,10 +175,6 @@ Learner = R6Class("Learner",
#' A complete list of candidate properties, grouped by task type, is stored in [`mlr_reflections$learner_properties`][mlr_reflections].
properties = NULL,

#' @field data_formats (`character()`)\cr
#' Supported data format, e.g. `"data.table"` or `"Matrix"`.
data_formats = NULL,

#' @template field_packages
packages = NULL,

Expand Down Expand Up @@ -213,7 +210,7 @@ Learner = R6Class("Learner",
#'
#' Note that this object is typically constructed via a derived class, e.g. [LearnerClassif] or [LearnerRegr].
initialize = function(id, task_type, param_set = ps(), predict_types = character(), feature_types = character(),
properties = character(), data_formats = "data.table", packages = character(), label = NA_character_, man = NA_character_) {
properties = character(), data_formats, packages = character(), label = NA_character_, man = NA_character_) {

self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
Expand All @@ -224,7 +221,7 @@ Learner = R6Class("Learner",
empty.ok = FALSE, .var.name = "predict_types")
private$.predict_type = predict_types[1L]
self$properties = sort(assert_subset(properties, mlr_reflections$learner_properties[[task_type]]))
self$data_formats = assert_subset(data_formats, mlr_reflections$data_formats)
if (!missing(data_formats)) warn_deprecated("Learner$initialize argument 'data_formats'")
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$man = assert_string(man, na.ok = TRUE)

Expand Down Expand Up @@ -456,6 +453,12 @@ Learner = R6Class("Learner",
),

active = list(
#' @field data_formats (`character()`)\cr
#' Supported data format. Always `"data.table"`.
#' This is deprecated and will be removed in the future.
data_formats = deprecated_binding("Learner$data_formats", "data.table"),


#' @field model (any)\cr
#' The fitted model. Only available after `$train()` has been called.
model = function(rhs) {
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ LearnerClassif = R6Class("LearnerClassif", inherit = Learner,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, param_set = ps(), predict_types = "response", feature_types = character(), properties = character(), data_formats = "data.table", packages = character(), label = NA_character_, man = NA_character_) {
initialize = function(id, param_set = ps(), predict_types = "response", feature_types = character(), properties = character(), data_formats, packages = character(), label = NA_character_, man = NA_character_) {
super$initialize(id = id, task_type = "classif", param_set = param_set, predict_types = predict_types,
feature_types = feature_types, properties = properties, data_formats = data_formats, packages = packages,
label = label, man = man)
Expand Down
1 change: 0 additions & 1 deletion R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
predict_types = c("response", "prob"),
properties = c("twoclass", "multiclass", "missings", "hotstart_forward", "validation", "internal_tuning", "marshal"),
man = "mlr3::mlr_learners_classif.debug",
data_formats = c("data.table", "Matrix"),
label = "Debug Learner for Classification"
)
},
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ LearnerRegr = R6Class("LearnerRegr", inherit = Learner,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, param_set = ps(), predict_types = "response", feature_types = character(), properties = character(), data_formats = "data.table", packages = character(), label = NA_character_, man = NA_character_) {
initialize = function(id, param_set = ps(), predict_types = "response", feature_types = character(), properties = character(), data_formats, packages = character(), label = NA_character_, man = NA_character_) {
super$initialize(id = id, task_type = "regr", param_set = param_set, feature_types = feature_types,
predict_types = predict_types, properties = properties, data_formats = data_formats, packages = packages,
predict_types = predict_types, properties = properties, data_formats, packages = packages,
label = label, man = man)
}
),
Expand Down
1 change: 0 additions & 1 deletion R/LearnerRegrDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
),
properties = "missings",
man = "mlr3::mlr_learners_regr.debug",
data_formats = c("data.table", "Matrix"),
label = "Debug Learner for Regression"
)
}
Expand Down
Loading

0 comments on commit d72be13

Please sign in to comment.