Skip to content

Commit

Permalink
refactor learner (only one LearnerTorch class)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jun 30, 2023
1 parent ae8b405 commit 9c6f6bd
Show file tree
Hide file tree
Showing 58 changed files with 811 additions and 1,432 deletions.
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ Collate:
'CallbackTorchProgress.R'
'ContextTorch.R'
'LearnerTorch.R'
'LearnerClassifTorchImage.R'
'LearnerClassifAlexNet.R'
'LearnerMLP.R'
'LearnerTorchImage.R'
'LearnerTorchAlexNet.R'
'LearnerTorchFeatureless.R'
'LearnerTorchMLP.R'
'LearnerTorchModel.R'
'ModelDescriptor.R'
'PipeOpModule.R'
Expand Down Expand Up @@ -115,6 +115,7 @@ Collate:
'learner_torch_methods.R'
'nn_graph.R'
'paramset_torchlearner.R'
'rd_info.R'
'reset_last_layer.R'
'task_dataset.R'
'utils.R'
15 changes: 6 additions & 9 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,12 @@ export(CallbackTorchCheckpoint)
export(CallbackTorchHistory)
export(CallbackTorchProgress)
export(ContextTorch)
export(LearnerClassifAlexNet)
export(LearnerClassifMLP)
export(LearnerClassifTorch)
export(LearnerClassifTorchFeatureless)
export(LearnerClassifTorchImage)
export(LearnerClassifTorchModel)
export(LearnerRegrMLP)
export(LearnerRegrTorch)
export(LearnerRegrTorchModel)
export(LearnerTorch)
export(LearnerTorchAlexNet)
export(LearnerTorchFeatureless)
export(LearnerTorchImage)
export(LearnerTorchMLP)
export(LearnerTorchModel)
export(PipeOpModule)
export(PipeOpTorch)
export(PipeOpTorchAvgPool1D)
Expand Down
203 changes: 72 additions & 131 deletions R/LearnerTorch.R
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
#' @title Abstract Base Class for a Torch Classification Learner
#' @title Abstract Base Class for a Torch Learner
#'
#' @name mlr_learners_classif.torch
#' @name mlr_learners_torch
#'
#' @description
#' This base class provides the basic functionality for training and prediction of a neural network.
#' All torch classifiction learners should inherit from the respective class, i.e.
#' [`LearnerClassifTorch`] for classification and [`LearnerRegrTorch`] for regression.
#' All torch learners should inherit from this class.
#'
#' It also allows to hook into the training loop via a callback mechanism.
#'
#' @template param_id
#' @template param_task_type
#' @template param_param_vals
#' @template param_optimizer
#' @template param_loss
#' @template param_param_set
#' @template param_properties
#' @template param_packages
#' @template param_predict_types
#' @template param_feature_types
#' @template param_man
#' @template param_label
#' @template param_callbacks
#'
#' @section State:
#' The state is a list with elements `network`, `optimizer`, `loss_fn` and `callbacks`.
#'
#' @template paramset_torchlearner
#'
#' @section Inheriting:
#' There are no seperate classes for classification and regression to inherit from.
#' Instead, the `task_type` must be specified in the initialize method.
#'
#' When inheriting from this class, one should overload two private methods:
#'
#' * `.network(task, param_vals)`\cr
Expand Down Expand Up @@ -43,113 +59,25 @@
#'
#' @family Learner
#' @export
LearnerClassifTorch = R6Class("LearnerClassifTorch",
inherit = LearnerClassif,
LearnerTorch = R6Class("LearnerTorch",
inherit = Learner,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @template param_id
#' @template param_param_vals
#' @template param_optimizer
#' @template param_loss
#' @template param_param_set
#' @template param_properties
#' @template param_packages
#' @template param_predict_types
#' @template param_feature_types
#' @template param_man
#' @template param_label
#' @template param_callbacks
initialize = function(id, optimizer, loss, param_set, properties = c("twoclass", "multiclass"), packages = character(0),
predict_types = c("response", "prob"), feature_types, man, label, callbacks = list()) {

learner_torch_initialize(self = self, private = private, super = super,
task_type = "classif",
id = id,
optimizer = optimizer,
loss = loss,
param_set = param_set,
properties = properties,
packages = packages,
predict_types = predict_types,
feature_types = feature_types,
man = man,
label = label,
callbacks = callbacks
initialize = function(id, task_type, param_set, properties, man, label, feature_types,
optimizer = NULL, loss = NULL, packages = NULL, predict_types = NULL, callbacks = list()) {
assert_choice(task_type, c("regr", "classif"))
predict_types = predict_types %??% switch(task_type,
regr = "response",
classif = c("response", "prob")
)
}
),
private = list(
.train = function(task) {
learner_torch_train(self, task)
},
.predict = function(task) {
learner_torch_predict(self, task)
},
.network = function(task, param_vals) stop(".network must be implemented."),
# the dataloader gets param_vals that may be different from self$param_set$values, e.g.
# when the dataloader for validation data is loaded, `shuffle` is set to FALSE.
.dataloader = function(task, param_vals) {
dataloader(
private$.dataset(task, param_vals),
batch_size = param_vals$batch_size %??% self$param_set$default$batch_size,
shuffle = param_vals$shuffle %??% self$param_set$default$shuffle
loss = loss %??% switch(task_type,
classif = t_loss("cross_entropy"),
regr = t_loss("mse")
)
},
.dataset = function(task, param_vals) stop(".dataset must be implemented."),
.optimizer = NULL,
.loss = NULL,
.param_set_base = NULL,
.callbacks = NULL,
deep_clone = function(name, value) deep_clone(self, private, super, name, value)
),
active = list(
#' @field network ([`nn_module()`][torch::nn_module])\cr
#' The network (only available after training).
network = function(rhs) learner_torch_network(self, rhs),
#' @field param_set ([`ParamSet`])\cr
#' The parameter set
param_set = function(rhs) learner_torch_param_set(self, rhs),
#' @field history ([`CallbackTorchHistory`])\cr
#' Shortcut for `learner$model$callbacks$history`.
history = function(rhs) learner_torch_history(self, rhs)
)
)
optimizer = optimizer %??% t_opt("adam")


#' @title Abstract Base Class for a Torch Regression Learner
#'
#' @name mlr_learners_regr.torch
#'
#' @description
#' This base class provides the basic functionality for training and prediction of a neural network.
#' All torch regression learners should inherit from the respective subclass.
#'
#' @inheritSection mlr_learners_classif.torch State
#' @inheritSection mlr_learners_classif.torch Parameters
#' @inheritSection mlr_learners_classif.torch Inheriting
#'
#' @family Learner
#' @export
LearnerRegrTorch = R6Class("LearnerRegrTorch",
inherit = LearnerRegr,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @template param_id
#' @template param_param_vals
#' @template param_optimizer
#' @template param_loss
#' @template param_param_set
#' @template param_properties
#' @template param_packages
#' @template param_predict_types
#' @template param_feature_types
#' @template param_man
#' @template param_label
#' @template param_callbacks
initialize = function(id, optimizer, loss, param_set, properties = character(0), packages = character(0),
predict_types = "response", feature_types, man, label, callbacks = list()) {
learner_torch_initialize(self = self, private = private, super = super,
task_type = "regr",
task_type = task_type,
id = id,
optimizer = optimizer,
loss = loss,
Expand Down Expand Up @@ -186,38 +114,51 @@ LearnerRegrTorch = R6Class("LearnerRegrTorch",
.loss = NULL,
.param_set_base = NULL,
.callbacks = NULL,
deep_clone = function(name, value) deep_clone(self, private, super, name, value)
deep_clone = function(name, value) {
private$.param_set = NULL # required to keep clone identical to original, otherwise tests get really ugly

if (name == "state" && !is.null(value)) {
# https://github.com/mlr-org/mlr3torch/issues/97
stopf("Deep clone of trained network is currently not supported.")
} else if (name == ".param_set") {
NULL
} else {
super$deep_clone(name, value)
}
}
),
active = list(
#' @field network ([`nn_module()`][torch::nn_module])\cr
#' The network (only available after training).
network = function(rhs) learner_torch_network(self, rhs),
network = function(rhs) {
if (is.null(self$state)) {
stopf("Cannot access network before training.")
}
self$state$model$network
},
#' @field param_set ([`ParamSet`])\cr
#' The parameter set
param_set = function(rhs) learner_torch_param_set(self, rhs),
param_set = function(rhs) {
if (is.null(private$.param_set)) {
private$.param_set = ParamSetCollection$new(c(
list(private$.param_set_base, private$.optimizer$param_set, private$.loss$param_set),
map(private$.callbacks, "param_set"))
)
}
private$.param_set
},
#' @field history ([`CallbackTorchHistory`])\cr
#' Shortcut for `learner$model$callbacks$history`.
history = function(rhs) learner_torch_history(self, rhs)
#' Shortcut for `learner$model$callbacks$history`.
history = function(rhs) {
assert_ro_binding(rhs)
if (is.null(self$state)) {
stopf("Cannot access history before training.")
}
if (is.null(self$model$callbacks$history)) {
warningf("No history found. Did you specify t_clbk(\"history\") during construction?")
return(NULL)
}
self$model$callbacks$history
}
)
)


deep_clone = function(self, private, super, name, value) {
private$.param_set = NULL # required to keep clone identical to original, otherwise tests get really ugly

if (name == "state") {
# https://github.com/mlr-org/mlr3torch/issues/97
if (!is.null(value)) {
stopf("Deep clone of trained network is currently not supported.")
} else {
# Note that private methods are available in super.
super$deep_clone(name, value)
}
} else if (name == ".param_set") {
# Otherwise the value$clone() is called on NULL which errs
NULL
} else {
# Note that private methods are available in super.
super$deep_clone(name, value)
}
}
33 changes: 17 additions & 16 deletions R/LearnerClassifAlexNet.R → R/LearnerTorchAlexNet.R
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
# TODO
#' @title AlexNet Image Classifier
#'
#' @templateVar id classif.alexnet
#' @template params_learner
#' @templateVar name alexnet
#' @templateVar task_types classif
#' @template learner
#' @template params_learner
#'
#' @description
#' Historic convolutional network for image classification.
#' Historic convolutional neural network for image classification.
#'
#' @section Parameters:
#' Parameters from [`LearnerClassifTorchImage`] and
#' Parameters from [`LearnerTorchImage`] and
#'
#' * `pretrained` :: `logical(1)`\cr
#' Whether to use the pretrained model.
#'
#' @references `r format_bib("krizhevsky2017imagenet")`
#' @include LearnerClassifTorchImage.R
#' @include LearnerTorchImage.R
#' @export
#' @examples
#' learner = lrn("classif.alexnet")
#' learner$param_set
LearnerClassifAlexNet = R6Class("LearnerClassifAlexNet",
inherit = LearnerClassifTorchImage,
LearnerTorchAlexNet = R6Class("LearnerTorchAlexNet",
inherit = LearnerTorchImage,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(optimizer = t_opt("adam"), loss = t_loss("cross_entropy"), callbacks = list()) {
initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) {
param_set = ps(
pretrained = p_lgl(default = TRUE, tags = "train")
)
# TODO: Freezing --> maybe as a callback?
super$initialize(
id = "classif.alexnet",
task_type = task_type,
id = paste0(task_type, ".alexnet"),
param_set = param_set,
man = "mlr3torch::mlr_learners_classif.alexnet",
man = "mlr3torch::mlr_learners.alexnet",
optimizer = optimizer,
loss = loss,
callbacks = callbacks,
Expand All @@ -41,21 +41,22 @@ LearnerClassifAlexNet = R6Class("LearnerClassifAlexNet",
),
private = list(
.network = function(task, param_vals) {
nout = if (self$task_type == "regr") 1 else length(task$class_names)
if (param_vals$pretrained %??% TRUE) {
network = torchvision::model_alexnet(pretrained = TRUE)

network$classifier$`6` = torch::nn_linear(
in_features = network$classifier$`6`$in_features,
out_features = length(task$class_names),
out_features = nout,
bias = TRUE
)
return(network)
}

torchvision::model_alexnet(pretrained = FALSE, num_classes = length(task$class_names))
torchvision::model_alexnet(pretrained = FALSE, num_classes = nout)
}
)
)

#' @include zzz.R
register_learner("classif.alexnet", LearnerClassifAlexNet)
register_learner("classif.alexnet", LearnerTorchAlexNet)
Loading

0 comments on commit 9c6f6bd

Please sign in to comment.