Skip to content

Commit

Permalink
mlr3 upkeep
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Oct 7, 2024
1 parent 9fc13e3 commit 2fa8369
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 28 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mlr3torch dev

* Stop using the deprecated `data_formats` argument and field.

# mlr3torch 0.1.1

* fix(preprocessing): regarding the construction of some `PipeOps` such as `po("trafo_resize")`
Expand Down
13 changes: 4 additions & 9 deletions R/DataBackendLazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@
#' Desired data format, e.g. `"data.table"` or `"Matrix"`.
#' @param na_rm (`logical(1)`)\cr
#' Whether to remove NAs or not.
#' @param data_formats (`character()`)\cr
#' Set of supported data formats. E.g. `"data.table"`.
#' These must be a subset of the data formats of the lazily constructed backend.
#' @param primary_key (`character(1)`)\cr
#' Name of the primary key column.
#'
Expand All @@ -72,7 +69,6 @@
#' constructor = constructor,
#' rownames = 1:10,
#' col_info = column_info,
#' data_formats = "data.table",
#' primary_key = "row_id"
#' )
#'
Expand All @@ -94,7 +90,7 @@ DataBackendLazy = R6Class("DataBackendLazy",
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(constructor, rownames, col_info, primary_key, data_formats) {
initialize = function(constructor, rownames, col_info, primary_key) {
private$.rownames = assert_integerish(rownames, unique = TRUE)
private$.col_info = assert_data_table(col_info, ncols = 3, min.rows = 1)
assert_permutation(colnames(col_info), c("id", "type", "levels"))
Expand All @@ -103,7 +99,7 @@ DataBackendLazy = R6Class("DataBackendLazy",
assert_choice(primary_key, col_info$id)
private$.constructor = assert_function(constructor, args = "backend")

super$initialize(data = NULL, primary_key = primary_key, data_formats = data_formats)
super$initialize(data = NULL, primary_key = primary_key)
},

#' @description
Expand All @@ -114,8 +110,8 @@ DataBackendLazy = R6Class("DataBackendLazy",
#' Duplicated row ids result in duplicated rows; duplicated column names lead to an exception.
#'
#' Accessing the data triggers the construction of the backend.
data = function(rows, cols, data_format = "data.table") {
self$backend$data(rows = rows, cols = cols, data_format = data_format)
data = function(rows, cols) {
self$backend$data(rows = rows, cols = cols)
},

#' @description
Expand Down Expand Up @@ -190,7 +186,6 @@ DataBackendLazy = R6Class("DataBackendLazy",
f(test_permutation, backend$colnames, private$.colnames, "column names")
f(test_equal_col_info, col_info(backend), private$.col_info, "column information")
# need to reverse the order for correct error message
f(function(x, y) test_subset(y, x), backend$data_formats, self$data_formats, "data formats")
private$.backend = backend
}
private$.backend
Expand Down
1 change: 0 additions & 1 deletion R/LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ LearnerTorch = R6Class("LearnerTorch",
param_set = self$param_set,
predict_types = predict_types,
properties = properties,
data_formats = "data.table",
label = label,
feature_types = feature_types,
man = man
Expand Down
3 changes: 1 addition & 2 deletions R/TaskClassif_mnist.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ load_task_mnist = function(id = "mnist") {
constructor = cached_constructor,
rownames = seq_len(70000),
col_info = load_col_info("mnist"),
primary_key = "..row_id",
data_formats = "data.table"
primary_key = "..row_id"
)

task = TaskClassif$new(
Expand Down
3 changes: 1 addition & 2 deletions R/TaskClassif_tiny_imagenet.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ load_task_tiny_imagenet = function(id = "tiny_imagenet") {
constructor = cached_constructor,
rownames = seq_len(120000),
col_info = load_col_info("tiny_imagenet"),
primary_key = "..row_id",
data_formats = "data.table"
primary_key = "..row_id"
)

task = TaskClassif$new(
Expand Down
15 changes: 5 additions & 10 deletions tests/testthat/test_DataBackendLazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ test_that("DataBackendLazy works", {
constructor = constructor,
rownames = n:1,
col_info = column_info,
primary_key = "row_id",
data_formats = "data.table"
primary_key = "row_id"
)

expect_r6(backend_lazy, c("DataBackend", "DataBackendLazy"))
Expand Down Expand Up @@ -97,8 +96,7 @@ test_that("DataBackendLazy works", {
constructor = constructor,
rownames = 1:10,
col_info = col_info,
primary_key = "a",
data_formats = "data.table"
primary_key = "a"
)$backend,
regexp = regexp
)
Expand All @@ -114,8 +112,7 @@ test_that("primary_key must be in col_info", {
constructor = function(backend) NULL,
col_info = data.table(id = "a", type = "integer", levels = list(NULL)),
rownames = 1,
primary_key = "b",
data_formats = "data.table"
primary_key = "b"
), regexp = "Must be element of")
})

Expand All @@ -130,8 +127,7 @@ test_that("primary_key must be the same for backends", {
constructor = constructor,
col_info = data.table(id = c("x", "y"), type = rep("integer", 2), levels = list(NULL, NULL)),
rownames = 1:5,
primary_key = "y",
data_formats = "data.table"
primary_key = "y"
)
expect_error(backend_lazy$backend, "primary key")
})
Expand All @@ -141,7 +137,6 @@ test_that("constructor must have argument backend", {
constructor = function() NULL,
col_info = data.table(id = c("x", "y"), type = rep("integer", 2), levels = list(NULL, NULL)),
rownames = 1:5,
primary_key = "y",
data_formats = "data.table"
primary_key = "y"
), regexp = "formal arguments")
})
8 changes: 5 additions & 3 deletions tests/testthat/test_LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,10 @@ test_that("resample() works", {

test_that("marshaling", {
task = tsk("mtcars")$filter(1:5)
learner = lrn("regr.mlp", batch_size = 150, epochs = 1, device = "cpu", encapsulate = c(train = "callr"),
learner = lrn("regr.mlp", batch_size = 150, epochs = 1, device = "cpu",
neurons = 20
)
learner$encapsulate("callr", lrn("regr.featureless"))
learner$train(task)
expect_false(learner$marshaled)
learner$marshal()$unmarshal()
Expand All @@ -431,9 +432,10 @@ test_that("marshaling", {
test_that("callr encapsulation and marshaling", {
skip_if_not_installed("callr")
task = tsk("mtcars")$filter(1:5)
learner = lrn("regr.mlp", batch_size = 150, epochs = 1, device = "cpu", encapsulate = c(train = "callr"),
learner = lrn("regr.mlp", batch_size = 150, epochs = 1, device = "cpu",
neurons = 20
)
learner$encapsulate("callr", lrn("regr.featureless"))
learner$train(task)
expect_prediction(learner$predict(task))
})
Expand Down Expand Up @@ -461,7 +463,7 @@ test_that("Input verification works during `$train()` (train-predict shapes work
)

# fallback learner cannot help in this case!
learner$fallback = lrn("classif.featureless")
learner$encapsulate("evaluate", fallback = lrn("classif.featureless"))
rr_faulty = resample(task_invalid, learner, rsmp("holdout"))
expect_true(nrow(rr_faulty$errors) == 1L)
rr1 = resample(task, learner, rsmp("holdout"))
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_PipeOpTorchModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ test_that("base_learner works", {
po("torch_model_regr")

glrn = as_learner(graph)
expect_equal(glrn$base_learner(return_po = TRUE)$id, "torch_model_regr")
expect_equal(glrn$base_learner(return_po = TRUE, recursive = 1)$id, "torch_model_regr")
})

test_that("internal_tuning", {
Expand Down

0 comments on commit 2fa8369

Please sign in to comment.