Skip to content

Commit

Permalink
fix issues from merging
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jul 11, 2023
1 parent 12e308d commit df0b7ae
Show file tree
Hide file tree
Showing 23 changed files with 145 additions and 226 deletions.
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ Collate:
'TorchDescriptor.R'
'TorchOptimizer.R'
'bibentries.R'
'expectations.R'
'imageuri.R'
'learner_torch_methods.R'
'nn_graph.R'
Expand Down
57 changes: 47 additions & 10 deletions R/LearnerTorch.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#' @title Abstract Base Class for a Torch Learner
#' @title Base Class for Torch Learners
#'
#' @name mlr_learners_torch
#'
Expand All @@ -11,25 +11,37 @@
#' @template param_id
#' @template param_task_type
#' @template param_param_vals
#' @template param_optimizer
#' @template param_loss
#' @template param_param_set
#' @template param_properties
#' @template param_packages
#' @template param_predict_types
#' @template param_feature_types
#' @template param_man
#' @template param_label
#' @template param_callbacks
#' @param predict_types (`character()`)\cr
#' The predict types.
#' See [`mlr_reflections$learner_predict_types`][mlr_reflections] for available values.
#' For regression, the default is `"response"`.
#' For classification, this defaults to `"response"` and `"prob"`.
#' To deviate from the defaults, it is necessary to overwrite the private `$.predict()` method.
#' @param loss (`NULL` or [`TorchLoss`])\cr
#' The loss to use for training.
#' Defaults to MSE for regression and cross entropy for classification.
#' @param optimizer (`NULL` or [`TorchOptimizer`])\cr
#' The optimizer to use for training.
#' Defaults to adam.
#' @param callbacks (`list()` of [`TorchCallback`]s)\cr
#' The callbacks to use for training.
#' Defaults to an empty` list()`, i.e. no callbacks.
#'
#' @section State:
#' The state is a list with elements `network`, `optimizer`, `loss_fn` and `callbacks`.
#' The state is a list with elements `network`, `optimizer`, `loss_fn`, `callbacks` and `seed`.
#'
#' @template paramset_torchlearner
#'
#' @section Inheriting:
#' There are no seperate classes for classification and regression to inherit from.
#' Instead, the `task_type` must be specified in the initialize method.
#' Instead, the `task_type` must be specified as a construction argument.
#' Currently, only classification and regression are supported.
#'
#' When inheriting from this class, one should overload two private methods:
#'
Expand All @@ -51,7 +63,15 @@
#' * `.dataloader(task, param_vals)`\cr
#' ([`Task`], `list()`) -> [`torch::dataloader`]\cr
#' Create a dataloader from the task.
#' Needs to respect at least `batch_size` and `shuffle` (otherwise predictions are permuted).
#' Needs to respect at least `batch_size` and `shuffle` (otherwise predictions can be permuted).
#'
#' To change the predict types, the private `.encode_prediction()` method can be overwritten:
#'
#' * `.encode_prediction(predict_tensor, task, param_vals)`\cr
#' ([`torch_tensor`], [`Task`], `list()`) -> `list()`\cr
#' Take in the raw predictions from `self$network` (`predict_tensor`) and encode them into a
#' format that can be converted to valid `mlr3` predictions using [`mlr3::as_prediction_data()`].
#' This must take `self$predict_type` into account.
#'
#' While it is possible to add parameters by specifying the `param_set` construction argument, it is currently
#' not possible to remove existing parameters, i.e. those listed in section *Parameters*.
Expand Down Expand Up @@ -129,10 +149,27 @@ LearnerTorch = R6Class("LearnerTorch",
),
private = list(
.train = function(task) {
learner_torch_train(self, task)
param_vals = self$param_set$get_values(tags = "train")
param_vals$device = auto_device(param_vals$device)
if (param_vals$seed == "random") param_vals$seed = sample.int(10000000L, 1L)

with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads, {
learner_torch_train_worker(self, private, super, task, param_vals)
})
},
.predict = function(task) {
learner_torch_predict(self, task)
param_vals = self$param_set$get_values(tags = "predict")
param_vals$device = auto_device(param_vals$device)

with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, {
self$network$eval()
data_loader = private$.dataloader_predict(task, param_vals)
predict_tensor = torch_network_predict(self$network, data_loader)
private$.encode_prediction(predict_tensor, task, param_vals)
})
},
.encode_prediction = function(predict_tensor, task, param_vals) {
encode_prediction(predict_tensor, self$predict_type, task)
},
.network = function(task, param_vals) stop(".network must be implemented."),
# the dataloader gets param_vals that may be different from self$param_set$values, e.g.
Expand Down
5 changes: 3 additions & 2 deletions R/LearnerTorchAlexNet.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ LearnerTorchAlexNet = R6Class("LearnerTorchAlexNet",
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) {
param_set = ps(
pretrained = p_lgl(default = TRUE, tags = "train")
pretrained = p_lgl(tags = c("required", "train"))
)
param_set$values = list(pretrained = TRUE)
# TODO: Freezing --> maybe as a callback?
super$initialize(
task_type = task_type,
Expand All @@ -42,7 +43,7 @@ LearnerTorchAlexNet = R6Class("LearnerTorchAlexNet",
private = list(
.network = function(task, param_vals) {
nout = if (self$task_type == "regr") 1 else length(task$class_names)
if (param_vals$pretrained %??% TRUE) {
if (param_vals$pretrained) {
network = torchvision::model_alexnet(pretrained = TRUE)

network$classifier$`6` = torch::nn_linear(
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerTorchFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ LearnerTorchFeatureless = R6Class("LearnerTorchFeatureless",
label = "Featureless Torch Learner",
param_set = ps(),
properties = properties,
# TODO: This should have all feature types, and have properties missing
feature_types = c("integer", "numeric", "factor", "ordered"),
# TODO: This should have all feature types, and have prop"),
feature_types = unname(mlr_reflections$task_feature_types),
man = "mlr3torch::mlr_learners.torch_featureless",
optimizer = optimizer,
loss = loss,
Expand Down
5 changes: 0 additions & 5 deletions R/LearnerTorchImage.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ LearnerTorchImage = R6Class("LearnerTorchImage",
regr = c(),
classif = c("twoclass", "multiclass")
)
predict_types = predict_types %??% switch(task_type,
regr = "response",
classif = c("response", "prob")
)

assert_param_set(param_set)
predefined_set = ps(
channels = p_int(1, tags = c("train", "predict", "required")),
Expand Down
31 changes: 19 additions & 12 deletions R/LearnerTorchMLP.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
#' @section Parameters:
#' Parameters from [`LearnerTorch`], as well as:
#'
#' * `activation` :: `character(1)`\cr
#' Activation function.
#' * `activation` :: `[nn_module]`\cr
#' The activation function. Is initialized to [`nn_relu`].
#' * `activation_args` :: named `list()`\cr
#' A named list with initialization arguments for the activation function.
#' This is intialized to an empty list.
#' * `layers` :: `integer(1)`\cr
#' The number of layers.
#' * `d_hidden` :: `numeric(1)`\cr
#' The dimension of the hidden layers.
#' * `p` :: `numeric(1)`\cr
#' The dropout probability.
#' Is initialized to `0.5`.
#'
#' @export
LearnerTorchMLP = R6Class("LearnerTorchMLP",
Expand All @@ -31,12 +33,19 @@ LearnerTorchMLP = R6Class("LearnerTorchMLP",
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(task_type, optimizer = NULL, loss = NULL, callbacks = list()) {
check_activation = crate(function(x) check_class(x, "nn_module"), .parent = topenv())
check_activation_args = crate(function(x) check_list(x, names = "unique"), .parent = topenv())
param_set = ps(
activation = p_fct(default = "relu", tags = "train", levels = mlr_reflections$torch$activations),
activation_args = p_uty(tags = "train", custom_check = check_list),
layers = p_int(lower = 0L, tags = c("train", "required")),
d_hidden = p_int(lower = 1L, tags = "train"),
p = p_dbl(default = 0.5, lower = 0, upper = 1, tags = "train")
layers = p_int(lower = 0L, tags = c("required", "train")),
d_hidden = p_int(lower = 1L, tags = c("required", "train")),
p = p_dbl(lower = 0, upper = 1, tags = c("required", "train")),
activation = p_uty(tags = c("required", "train"), custom_check = check_activation),
activation_args = p_uty(tags = c("required", "train"), custom_check = check_activation_args)
)
param_set$set_values(
activation = nn_relu,
p = 0.5,
activation_args = list()
)
properties = switch(task_type,
regr = character(0),
Expand Down Expand Up @@ -76,8 +85,6 @@ LearnerTorchMLP = R6Class("LearnerTorchMLP",

make_mlp = function(task, activation, layers, d_hidden, p, activation_args) {
task_type = task$task_type
activation = activation %??% "relu"
act = getFromNamespace(paste0("nn_", activation), ns = "torch")
layers = layers
d_hidden = d_hidden
if (layers > 0) assert_true(!is.null(d_hidden))
Expand All @@ -94,19 +101,19 @@ make_mlp = function(task, activation, layers, d_hidden, p, activation_args) {
}

# This way, dropout_args will have length 0 if p is `NULL`
dropout_args = list(p = p)
dropout_args = list()
dropout_args$p = p

modules = list(
nn_linear(length(task$feature_names), d_hidden),
invoke(act, .args = activation_args),
invoke(activation, .args = activation_args),
invoke(nn_dropout, .args = dropout_args)
)

for (i in seq_len(layers - 1L)) {
modules = c(modules, list(
nn_linear(d_hidden, d_hidden),
invoke(act, .args = activation_args),
invoke(activation, .args = activation_args),
invoke(nn_dropout, .args = dropout_args)
))
}
Expand Down
24 changes: 11 additions & 13 deletions R/LearnerTorchModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
#'
#' @description
#' Create a torch learner from an instantiated [`nn_module()`].
#' This is learner is used internally by [`PipeOpTorchModelClassif`] and [`PipeOpTorchModelRegr`].
#'
#' For classification, the output of the network must be the scores (before the softmax).
#'
#' @template param_task_type
Expand All @@ -18,16 +16,18 @@
#' @template param_loss
#' @template param_callbacks
#' @template param_packages
#' @param feature_types (`character()`)\cr
#' @param feature_types (`NULL` or `character()`)\cr
#' The feature types. Defaults to all available feature types.
#' @template param_properties
#'
#' @param properties (`NULL` or `character()`)\cr
#' The properties of the learner.
#' Defaults to all available properties for the given task type.
#' @section Parameters: See [`LearnerTorch`]
#' @family Learner
#' @family Graph Network
#' @include LearnerTorch.R
#' @export
#' @examples
#' # We show the learner using a classification task
#'
#' # The iris task has 4 features and 3 classes
#' network = nn_linear(4, 3)
Expand All @@ -48,8 +48,6 @@
#' epochs = 1
#' )
#'
#'
#'
#' # A simple train-predict
#' ids = partition(task)
#' learner$train(task, ids$train)
Expand All @@ -69,11 +67,11 @@ LearnerTorchModel = R6Class("LearnerTorchModel",
} else {
assert_subset(feature_types, mlr_reflections$task_feature_types)
}
properties = properties %??% switch(task_type,
regr = character(),
classif = c("twoclass", "multiclass"),
stopf("Invalid task type '%s'.", task_type)
)
if (is.null(properties)) {
properties = mlr_reflections$learner_properties[[task_type]]
} else {
properties = assert_subset(properties, mlr_reflections$learner_properties[[task_type]])
}
super$initialize(
id = paste0(task_type, ".model"),
task_type = task_type,
Expand All @@ -97,7 +95,7 @@ LearnerTorchModel = R6Class("LearnerTorchModel",
task,
feature_ingress_tokens = private$.ingress_tokens,
target_batchgetter = target_batchgetter(self$task_type),
device = param_vals$device %??% self$param_set$default$device
device = param_vals$device
)
},
.network_stored = NULL,
Expand Down
3 changes: 2 additions & 1 deletion R/PipeOpTorchModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
#' @template paramset_torchlearner
#'
#' @section Internals:
#' A [`LearnerTorchModel`] is created by calling [`model_descriptor_to_learner()`].
#' A [`LearnerTorchModel`] is created by calling [`model_descriptor_to_learner()`] on the
#' provided [`ModelDescriptor`] that is received through the input channel.
#' Then the parameters are set according to the parameters specified in `PipeOpTorchModel` and
#' its '$train()` method is called on the [`Task`] stored in the [`ModelDescriptor`].
#'
Expand Down
5 changes: 0 additions & 5 deletions R/expectations.R

This file was deleted.

30 changes: 0 additions & 30 deletions R/learner_torch_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,6 @@ normalize_to_list = function(x) {
x
}

learner_torch_train = function(self, task) {
private = self$.__enclos_env__$private
super = self$.__enclos_env__$super
param_vals = self$param_set$get_values(tags = "train")

param_vals$device = auto_device(param_vals$device)
if (param_vals$seed == "random") param_vals$seed = sample.int(10000000L, 1L)

with_torch_settings(seed = param_vals$seed, num_threads = param_vals$num_threads, {
learner_torch_train_worker(self, private, super, task, param_vals)
})
}

learner_torch_initialize = function(
self,
private,
Expand Down Expand Up @@ -347,20 +334,3 @@ measure_prediction = function(pred_tensor, measures, task, row_ids) {
}
)
}

# Here are the standard methods that are shared between all the TorchLearners
learner_torch_predict = function(self, task) {
private = self$.__enclos_env__$private
model = self$state$model
network = model$network
param_vals = self$param_set$get_values(tags = "predict")

param_vals$device = auto_device(param_vals$device)

with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, {
network$eval()
data_loader = private$.dataloader_predict(task, param_vals)
prediction = torch_network_predict(network, data_loader)
encode_prediction(prediction, self$predict_type, task)
})
}
8 changes: 1 addition & 7 deletions R/nn_graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,10 @@ model_descriptor_to_learner = function(model_descriptor) {
)
network$reset_parameters()

# FIXME: How to do this better?
properties = switch(task_type,
regr = character(),
classif = c("twoclass", "multiclass")
)

learner = LearnerTorchModel$new(
task_type = task_type,
network = network,
properties = properties,
properties = NULL,
ingress_tokens = ingress_tokens,
optimizer = optimizer,
loss = loss,
Expand Down
Loading

0 comments on commit df0b7ae

Please sign in to comment.