From 755b105f0ddcf30d2ee415bd9d4ae5558f9f1909 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 11 Jul 2023 10:29:18 +0200 Subject: [PATCH] small corrections --- R/LearnerTorch.R | 15 ++++----------- R/LearnerTorchAlexNet.R | 1 - R/LearnerTorchFeatureless.R | 1 - R/LearnerTorchImage.R | 2 +- R/learner_torch_methods.R | 10 ++++++++++ man/mlr_learners_torch.Rd | 5 +++-- 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/R/LearnerTorch.R b/R/LearnerTorch.R index a7601407..e65fa3fb 100644 --- a/R/LearnerTorch.R +++ b/R/LearnerTorch.R @@ -22,7 +22,8 @@ #' See [`mlr_reflections$learner_predict_types`][mlr_reflections] for available values. #' For regression, the default is `"response"`. #' For classification, this defaults to `"response"` and `"prob"`. -#' To deviate from the defaults, it is necessary to overwrite the private `$.predict()` method. +#' To deviate from the defaults, it is necessary to overwrite the private `$.encode_prediction()` +#' method, see section *Inheriting*. #' @param loss (`NULL` or [`TorchLoss`])\cr #' The loss to use for training. #' Defaults to MSE for regression and cross entropy for classification. @@ -71,7 +72,7 @@ #' ([`torch_tensor`], [`Task`], `list()`) -> `list()`\cr #' Take in the raw predictions from `self$network` (`predict_tensor`) and encode them into a #' format that can be converted to valid `mlr3` predictions using [`mlr3::as_prediction_data()`]. -#' This must take `self$predict_type` into account. +#' This method must take `self$predict_type` into account. #' #' While it is possible to add parameters by specifying the `param_set` construction argument, it is currently #' not possible to remove existing parameters, i.e. those listed in section *Parameters*. @@ -88,15 +89,6 @@ LearnerTorch = R6Class("LearnerTorch", initialize = function(id, task_type, param_set, properties, man, label, feature_types, optimizer = NULL, loss = NULL, packages = NULL, predict_types = NULL, callbacks = list()) { assert_choice(task_type, c("regr", "classif")) - predict_types = predict_types %??% switch(task_type, - regr = "response", - classif = c("response", "prob") - ) - loss = loss %??% switch(task_type, - classif = t_loss("cross_entropy"), - regr = t_loss("mse") - ) - optimizer = optimizer %??% t_opt("adam") learner_torch_initialize(self = self, private = private, super = super, task_type = task_type, @@ -118,6 +110,7 @@ LearnerTorch = R6Class("LearnerTorch", #' @field network ([`nn_module()`][torch::nn_module])\cr #' The network (only available after training). network = function(rhs) { + assert_ro_binding(rhs) if (is.null(self$state)) { stopf("Cannot access network before training.") } diff --git a/R/LearnerTorchAlexNet.R b/R/LearnerTorchAlexNet.R index d853fce5..7bba78a6 100644 --- a/R/LearnerTorchAlexNet.R +++ b/R/LearnerTorchAlexNet.R @@ -27,7 +27,6 @@ LearnerTorchAlexNet = R6Class("LearnerTorchAlexNet", pretrained = p_lgl(tags = c("required", "train")) ) param_set$values = list(pretrained = TRUE) - # TODO: Freezing --> maybe as a callback? super$initialize( task_type = task_type, id = paste0(task_type, ".alexnet"), diff --git a/R/LearnerTorchFeatureless.R b/R/LearnerTorchFeatureless.R index 42f69676..d51a2583 100644 --- a/R/LearnerTorchFeatureless.R +++ b/R/LearnerTorchFeatureless.R @@ -31,7 +31,6 @@ LearnerTorchFeatureless = R6Class("LearnerTorchFeatureless", label = "Featureless Torch Learner", param_set = ps(), properties = properties, - # TODO: This should have all feature types, and have prop"), feature_types = unname(mlr_reflections$task_feature_types), man = "mlr3torch::mlr_learners.torch_featureless", optimizer = optimizer, diff --git a/R/LearnerTorchImage.R b/R/LearnerTorchImage.R index 84ae4414..de9b34c3 100644 --- a/R/LearnerTorchImage.R +++ b/R/LearnerTorchImage.R @@ -48,7 +48,7 @@ LearnerTorchImage = R6Class("LearnerTorchImage", callbacks = list(), packages = c("torchvision", "magick"), man, properties = NULL, predict_types = NULL) { properties = properties %??% switch(task_type, - regr = c(), + regr = character(0), classif = c("twoclass", "multiclass") ) assert_param_set(param_set) diff --git a/R/learner_torch_methods.R b/R/learner_torch_methods.R index 0fa9919f..d3b5a2b8 100644 --- a/R/learner_torch_methods.R +++ b/R/learner_torch_methods.R @@ -26,6 +26,16 @@ learner_torch_initialize = function( label, callbacks ) { + predict_types = predict_types %??% switch(task_type, + regr = "response", + classif = c("response", "prob") + ) + loss = loss %??% switch(task_type, + classif = t_loss("cross_entropy"), + regr = t_loss("mse") + ) + optimizer = optimizer %??% t_opt("adam") + private$.optimizer = as_torch_optimizer(optimizer, clone = TRUE) private$.optimizer$param_set$set_id = "opt" diff --git a/man/mlr_learners_torch.Rd b/man/mlr_learners_torch.Rd index 8ec5824b..65c3828c 100644 --- a/man/mlr_learners_torch.Rd +++ b/man/mlr_learners_torch.Rd @@ -83,7 +83,7 @@ To change the predict types, the private \code{.encode_prediction()} method can (\code{\link{torch_tensor}}, \code{\link{Task}}, \code{list()}) -> \code{list()}\cr Take in the raw predictions from \code{self$network} (\code{predict_tensor}) and encode them into a format that can be converted to valid \code{mlr3} predictions using \code{\link[mlr3:as_prediction_data]{mlr3::as_prediction_data()}}. -This must take \code{self$predict_type} into account. +This method must take \code{self$predict_type} into account. } While it is possible to add parameters by specifying the \code{param_set} construction argument, it is currently @@ -204,7 +204,8 @@ The predict types. See \code{\link[=mlr_reflections]{mlr_reflections$learner_predict_types}} for available values. For regression, the default is \code{"response"}. For classification, this defaults to \code{"response"} and \code{"prob"}. -To deviate from the defaults, it is necessary to overwrite the private \verb{$.predict()} method.} +To deviate from the defaults, it is necessary to overwrite the private \verb{$.encode_prediction()} +method, see section \emph{Inheriting}.} \item{\code{callbacks}}{(\code{list()} of \code{\link{TorchCallback}}s)\cr The callbacks to use for training.