Commit 755b105

small corrections

sebffischer committed Jul 11, 2023
1 parent df0b7ae
Showing 6 changed files with 18 additions and 16 deletions.
15 changes: 4 additions & 11 deletions R/LearnerTorch.R
@@ -22,7 +22,8 @@
#' See [`mlr_reflections$learner_predict_types`][mlr_reflections] for available values.
#' For regression, the default is `"response"`.
#' For classification, this defaults to `"response"` and `"prob"`.
#' To deviate from the defaults, it is necessary to overwrite the private `$.predict()` method.
#' To deviate from the defaults, it is necessary to overwrite the private `$.encode_prediction()`
#' method, see section *Inheriting*.
#' @param loss (`NULL` or [`TorchLoss`])\cr
#' The loss to use for training.
#' Defaults to MSE for regression and cross entropy for classification.
@@ -71,7 +72,7 @@
#' ([`torch_tensor`], [`Task`], `list()`) -> `list()`\cr
#' Take in the raw predictions from `self$network` (`predict_tensor`) and encode them into a
#' format that can be converted to valid `mlr3` predictions using [`mlr3::as_prediction_data()`].
#' This must take `self$predict_type` into account.
#' This method must take `self$predict_type` into account.
#'
#' While it is possible to add parameters by specifying the `param_set` construction argument, it is currently
#' not possible to remove existing parameters, i.e. those listed in section *Parameters*.
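The roxygen text above points to the private `$.encode_prediction()` method as the extension point for custom predict types. Purely as an illustration of that interface, a subclass override might look roughly like the sketch below; the argument names (`predict_tensor`, `task`, `param_vals`) and the exact shape of the returned list are assumptions inferred from the description, not the package's documented signature.

# Hypothetical sketch only -- argument names and return format are assumed.
LearnerTorchCustomRegr = R6::R6Class("LearnerTorchCustomRegr",
  inherit = LearnerTorch,
  private = list(
    .encode_prediction = function(predict_tensor, task, param_vals) {
      # turn the raw network output into plain numeric predictions
      response = as.numeric(torch::as_array(predict_tensor))
      # the method must take self$predict_type into account
      if (self$predict_type == "response") {
        list(response = response)
      } else {
        stop(sprintf("Predict type '%s' is not supported.", self$predict_type))
      }
    }
  )
)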
@@ -88,15 +89,6 @@ LearnerTorch = R6Class("LearnerTorch",
initialize = function(id, task_type, param_set, properties, man, label, feature_types,
optimizer = NULL, loss = NULL, packages = NULL, predict_types = NULL, callbacks = list()) {
assert_choice(task_type, c("regr", "classif"))
predict_types = predict_types %??% switch(task_type,
regr = "response",
classif = c("response", "prob")
)
loss = loss %??% switch(task_type,
classif = t_loss("cross_entropy"),
regr = t_loss("mse")
)
optimizer = optimizer %??% t_opt("adam")

learner_torch_initialize(self = self, private = private, super = super,
task_type = task_type,
@@ -118,6 +110,7 @@ LearnerTorch = R6Class("LearnerTorch",
#' @field network ([`nn_module()`][torch::nn_module])\cr
#' The network (only available after training).
network = function(rhs) {
assert_ro_binding(rhs)
if (is.null(self$state)) {
stopf("Cannot access network before training.")
}
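The added `assert_ro_binding(rhs)` (an mlr3misc helper) makes the `$network` active binding explicitly read-only: assigning to it now errors instead of being silently ignored. A rough usage sketch, where the learner id and parameter values are assumptions for illustration and may differ from the actual package:

library(mlr3)
library(mlr3torch)

# learner id and parameter values are assumed for illustration
l = lrn("classif.torch_featureless", epochs = 1, batch_size = 16)
# l$network           # errors: "Cannot access network before training."
l$train(tsk("iris"))
l$network             # the trained nn_module
# l$network = NULL    # errors: the binding is read-only via assert_ro_binding()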
1 change: 0 additions & 1 deletion R/LearnerTorchAlexNet.R
@@ -27,7 +27,6 @@ LearnerTorchAlexNet = R6Class("LearnerTorchAlexNet",
pretrained = p_lgl(tags = c("required", "train"))
)
param_set$values = list(pretrained = TRUE)
# TODO: Freezing --> maybe as a callback?
super$initialize(
task_type = task_type,
id = paste0(task_type, ".alexnet"),
1 change: 0 additions & 1 deletion R/LearnerTorchFeatureless.R
@@ -31,7 +31,6 @@ LearnerTorchFeatureless = R6Class("LearnerTorchFeatureless",
label = "Featureless Torch Learner",
param_set = ps(),
properties = properties,
# TODO: This should have all feature types, and have prop"),
feature_types = unname(mlr_reflections$task_feature_types),
man = "mlr3torch::mlr_learners.torch_featureless",
optimizer = optimizer,
2 changes: 1 addition & 1 deletion R/LearnerTorchImage.R
@@ -48,7 +48,7 @@ LearnerTorchImage = R6Class("LearnerTorchImage",
callbacks = list(), packages = c("torchvision", "magick"), man, properties = NULL,
predict_types = NULL) {
properties = properties %??% switch(task_type,
regr = c(),
regr = character(0),
classif = c("twoclass", "multiclass")
)
assert_param_set(param_set)
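Replacing `c()` with `character(0)` here is a real fix, not a style change: in base R, `c()` with no arguments returns `NULL`, so the regression branch previously produced no character vector at all, whereas `character(0)` keeps `properties` type-stable. For example:

c()                         # NULL
character(0)                # character(0)
is.character(c())           # FALSE
is.character(character(0))  # TRUE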
10 changes: 10 additions & 0 deletions R/learner_torch_methods.R
@@ -26,6 +26,16 @@ learner_torch_initialize = function(
label,
callbacks
) {
predict_types = predict_types %??% switch(task_type,
regr = "response",
classif = c("response", "prob")
)
loss = loss %??% switch(task_type,
classif = t_loss("cross_entropy"),
regr = t_loss("mse")
)
optimizer = optimizer %??% t_opt("adam")

private$.optimizer = as_torch_optimizer(optimizer, clone = TRUE)
private$.optimizer$param_set$set_id = "opt"

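This hunk moves the fallback defaults for `predict_types`, `loss`, and `optimizer` out of `LearnerTorch$initialize()` (removed above) into `learner_torch_initialize()`, so every caller of the helper gets them. The `%??%` operator from mlr3misc returns its left-hand side unless that is `NULL`, in which case it falls back to the right-hand side. A small sketch of the same pattern, using a made-up helper that is not part of mlr3torch:

library(mlr3misc)

# Hypothetical helper showing the null-coalescing default pattern.
pick_optimizer_id = function(optimizer = NULL) {
  optimizer %??% "adam"
}

pick_optimizer_id()       # "adam"  (NULL falls back to the default)
pick_optimizer_id("sgd")  # "sgd"   (a supplied value wins)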
5 changes: 3 additions & 2 deletions man/mlr_learners_torch.Rd

Some generated files are not rendered by default.
