Cb #111

Merged
merged 19 commits into from
Jul 7, 2023
2 changes: 2 additions & 0 deletions .Rbuildignore
@@ -24,3 +24,5 @@
^\.lintr$
^info$
^README_files$
^TODO\.md$
^\.ignore$
3 changes: 3 additions & 0 deletions .github/workflows/r-cmd-check.yml
@@ -14,6 +14,9 @@ on:
pull_request:
branches:
- main
schedule:
- cron: '0 4 * * 1'


name: r-cmd-check

13 changes: 6 additions & 7 deletions DESCRIPTION
@@ -14,7 +14,7 @@ Authors@R:
comment = c(ORCID = "0000-0001-6002-6980")),
person(given = "Lukas",
family = "Burk",
role = "aut",
role = "ctb",
email = "[email protected]",
comment = c(ORCID = "0000-0001-7528-3795")),
person(given = "Martin",
@@ -63,12 +63,12 @@ Encoding: UTF-8
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3
Collate:
'CallbackTorch.R'
'CallbackSet.R'
'zzz.R'
'TorchCallback.R'
'CallbackTorchCheckpoint.R'
'CallbackTorchHistory.R'
'CallbackTorchProgress.R'
'CallbackSetCheckpoint.R'
'CallbackSetHistory.R'
'CallbackSetProgress.R'
'ContextTorch.R'
'LearnerTorch.R'
'LearnerClassifTorchImage.R'
@@ -99,10 +99,9 @@ Collate:
'PipeOpTorchReshape.R'
'PipeOpTorchSoftmax.R'
'TaskClassif_tiny_imagenet.R'
'TorchDescriptor.R'
'TorchOptimizer.R'
'TorchWrapper.R'
'bibentries.R'
'helper.R'
'imageuri.R'
'learner_torch_methods.R'
'nn_graph.R'
12 changes: 6 additions & 6 deletions NAMESPACE
@@ -33,10 +33,10 @@ S3method(t_opt,"NULL")
S3method(t_opt,character)
S3method(t_opts,"NULL")
S3method(t_opts,character)
export(CallbackTorch)
export(CallbackTorchCheckpoint)
export(CallbackTorchHistory)
export(CallbackTorchProgress)
export(CallbackSet)
export(CallbackSetCheckpoint)
export(CallbackSetHistory)
export(CallbackSetProgress)
export(ContextTorch)
export(LearnerClassifAlexNet)
export(LearnerClassifMLP)
@@ -110,17 +110,17 @@ export(PipeOpTorchTanhShrink)
export(PipeOpTorchThreshold)
export(PipeOpTorchUnsqueeze)
export(TorchCallback)
export(TorchDescriptor)
export(TorchIngressToken)
export(TorchLoss)
export(TorchOptimizer)
export(TorchWrapper)
export(as_torch_callback)
export(as_torch_callbacks)
export(as_torch_loss)
export(as_torch_optimizer)
export(batchgetter_categ)
export(batchgetter_num)
export(callback_torch)
export(callback_set)
export(imageuri)
export(mlr3torch_callbacks)
export(mlr3torch_losses)
96 changes: 63 additions & 33 deletions R/CallbackTorch.R → R/CallbackSet.R
@@ -1,24 +1,27 @@
#' @title Base Class for Torch Callbacks
#' @title Base Class for Callbacks
#'
#' @name mlr_callbacks_torch
#' @name mlr_callback_set
#'
#' @description
#' Base class from which Callbacks should inherit.
#' Base class from which callbacks should inherit (see section *Inheriting*).
#' A callback set is a collection of functions that are executed at different stages of the training loop.
#' They can be used to gain more control over the training process of a neural network without
#' having to write everything from scratch.
#'
#' For each available stage (see section *Stages*) a public method `$on_<stage>(ctx)` can be defined.
#' This must be a function with argument `ctx`, which is a [`ContextTorch`].
#' When a learner is trained, at a specific `<stage>`, the `$on_<stage>(ctx)` method of the callback is
#' executed, where the `ctx` represents the current state of the training loop.
#' Different stages of a callback can communicate with each other by assigning values to `$self`.
#' When used in a torch learner, the `CallbackSet` is wrapped in a [`TorchCallback`].
#' The latter's parameter set represents the arguments of the [`CallbackSet`]'s `$initialize()` method.
#'
#' When used in a torch learner, the `CallbackTorch` is wrapped in a [`TorchCallback`].
#' The latter's parameter set represents the arguments of the [`CallbackTorch`]'s `$initialize()` method and can
#' be specified in the learner. The callback is then initialized at the beginning of the training loop.
#' @section Inheriting:
#' For each available stage (see section *Stages*) a public method `$on_<stage>()` can be defined.
#' The evaluation context (a [`ContextTorch`]) can be accessed via `self$ctx`, which contains
#' the current state of the training loop.
#' This context is assigned at the beginning of the training loop and removed afterwards.
#' Different stages of a callback can communicate with each other by assigning values to `self`.
#'
#' For creating custom callbacks, the function [`torch_callback()`] is recommended, which creates a
#' [`CallbackTorch`] and then wraps it in a [`TorchCallback`].
#' `CallbackSet` and then wraps it in a [`TorchCallback`].
#' To create a `CallbackSet` the convenience function [`callback_set()`] can be used.
#' These functions perform checks, for example that stage names are not accidentally misspelled.
#'
#' @section Stages:
#' * `begin` :: Run before the training loop begins.
@@ -33,18 +36,35 @@
#' * `end` :: Run at last, using `on.exit()`.
#' @family Callback
#' @export
CallbackTorch = R6Class("CallbackTorch",
CallbackSet = R6Class("CallbackSet",
lock_objects = FALSE,
cloneable = FALSE
public = list(
#' @field ctx ([`ContextTorch`] or `NULL`)\cr
#' The evaluation context for the callback.
#' This field should always be `NULL` except during the `$train()` call of the torch learner.
ctx = NULL
),
private = list(
deep_clone = function(name, value) {
if (name == "ctx" && !is.null(value)) {
stopf("CallbackSet instances must never be cloned unless the ctx is NULL.")
} else {
value
}
}
)
)
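
The pattern introduced above (stage methods without arguments that read `self$ctx`, plus a `deep_clone()` guard) can be illustrated with a small stand-alone sketch. It assumes only the R6 package; `CallbackSetSketch` and `CallbackSetCounter` are hypothetical names, a plain environment stands in for `ContextTorch`, and `stop()` replaces `stopf()`. It is not the package's implementation.

```r
library(R6)

# Hypothetical stand-in for CallbackSet: the context lives in a public field,
# and deep clones are refused while a context is attached.
CallbackSetSketch = R6Class("CallbackSetSketch",
  lock_objects = FALSE,
  public = list(
    ctx = NULL
  ),
  private = list(
    deep_clone = function(name, value) {
      if (name == "ctx" && !is.null(value)) {
        stop("CallbackSet instances must never be cloned unless the ctx is NULL.")
      }
      value
    }
  )
)

# Hypothetical callback: stage methods take no arguments and read self$ctx.
# Stages share state by assigning to self (possible because lock_objects = FALSE).
CallbackSetCounter = R6Class("CallbackSetCounter",
  inherit = CallbackSetSketch,
  lock_objects = FALSE,
  public = list(
    on_begin = function() {
      self$seen = integer(0)
    },
    on_epoch_end = function() {
      self$seen = c(self$seen, self$ctx$epoch)
    }
  )
)

cb = CallbackSetCounter$new()
ok_clone = cb$clone(deep = TRUE) # ctx is NULL, so a deep clone is allowed

# Simulated training loop; a plain environment plays the role of ContextTorch.
ctx = new.env()
cb$ctx = ctx
cb$on_begin()
for (epoch in 1:3) {
  ctx$epoch = epoch
  cb$on_epoch_end()
}

# While the context is attached, a deep clone raises the guard's error.
clone_error = tryCatch(cb$clone(deep = TRUE), error = function(e) conditionMessage(e))

cb$ctx = NULL # the context is removed again after training
```

Note that the guard only runs for `clone(deep = TRUE)`; R6 does not call `deep_clone()` for shallow clones.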

#' @title Create a Callback Torch
#' @title Create a Set of Callbacks for Torch
#'
#' @description
#' Creates an `R6ClassGenerator` inheriting from [`CallbackTorch`].
#' Creates an `R6ClassGenerator` inheriting from [`CallbackSet`].
#' Additionally performs checks, for example that stage names are not accidentally misspelled.
#' To create a [`TorchCallback`] use [`torch_callback()`].
#'
#' In order for the resulting class to be cloneable, the private method `$deep_clone()` must be
#' provided.
#'
#' @param classname (`character(1)`)\cr
#' The class name.
#' @param on_begin,on_end,on_epoch_begin,on_before_valid,on_epoch_end,on_batch_begin,on_batch_end,on_after_backward,on_batch_valid_begin,on_batch_valid_end (`function`)\cr
@@ -57,12 +77,17 @@ CallbackTorch = R6Class("CallbackTorch",
#' The parent environment for the [`R6Class`].
#' @param inherit (`R6ClassGenerator`)\cr
#' From which class to inherit.
#' This class must either be [`CallbackTorch`] (default) or inherit from it.
#'
#' This class must either be [`CallbackSet`] (default) or inherit from it.
#' @param lock_objects (`logical(1)`)\cr
#' Whether to lock the objects of the resulting [`R6Class`].
#' If `FALSE` (default), values can be freely assigned to `self` without declaring them in the
#' class definition.
#' @family Callback
#'
#' @return [`CallbackSet`]
#'
#' @export
callback_torch = function(
callback_set = function(
classname,
# training
on_begin = NULL,
@@ -78,20 +103,21 @@ callback_torch = function(
on_batch_valid_end = NULL,
# other methods
initialize = NULL,
public = NULL, private = NULL, active = NULL, parent_env = parent.frame(), inherit = CallbackTorch
public = NULL, private = NULL, active = NULL, parent_env = parent.frame(), inherit = CallbackSet,
lock_objects = FALSE
) {
assert_true(startsWith(classname, "CallbackTorch"))
assert_true(startsWith(classname, "CallbackSet"))
more_public = list(
on_begin = assert_function(on_begin, args = "ctx", null.ok = TRUE),
on_end = assert_function(on_end, args = "ctx", null.ok = TRUE),
on_epoch_begin = assert_function(on_epoch_begin, args = "ctx", null.ok = TRUE),
on_before_valid = assert_function(on_before_valid, args = "ctx", null.ok = TRUE),
on_epoch_end = assert_function(on_epoch_end, args = "ctx", null.ok = TRUE),
on_batch_begin = assert_function(on_batch_begin, args = "ctx", null.ok = TRUE),
on_batch_end = assert_function(on_batch_end, args = "ctx", null.ok = TRUE),
on_after_backward = assert_function(on_after_backward, args = "ctx", null.ok = TRUE),
on_batch_valid_begin = assert_function(on_batch_valid_begin, args = "ctx", null.ok = TRUE),
on_batch_valid_end = assert_function(on_batch_valid_end, args = "ctx", null.ok = TRUE)
on_begin = assert_function(on_begin, nargs = 0, null.ok = TRUE),
on_end = assert_function(on_end, nargs = 0, null.ok = TRUE),
on_epoch_begin = assert_function(on_epoch_begin, nargs = 0, null.ok = TRUE),
on_before_valid = assert_function(on_before_valid, nargs = 0, null.ok = TRUE),
on_epoch_end = assert_function(on_epoch_end, nargs = 0, null.ok = TRUE),
on_batch_begin = assert_function(on_batch_begin, nargs = 0, null.ok = TRUE),
on_batch_end = assert_function(on_batch_end, nargs = 0, null.ok = TRUE),
on_after_backward = assert_function(on_after_backward, nargs = 0, null.ok = TRUE),
on_batch_valid_begin = assert_function(on_batch_valid_begin, nargs = 0, null.ok = TRUE),
on_batch_valid_end = assert_function(on_batch_valid_end, nargs = 0, null.ok = TRUE)
)

assert_function(initialize, null.ok = TRUE)
@@ -111,14 +137,18 @@ callback_torch = function(
paste(paste0("'", invalid_stages, "'"), collapse = ", ")
)
}
cloneable = test_function(private$deep_clone, args = c("name", "value"))

assert_list(private, null.ok = TRUE, names = "unique")
assert_list(active, null.ok = TRUE, names = "unique")
assert_environment(parent_env)
assert_inherits_classname(inherit, "CallbackTorch")
assert_inherits_classname(inherit, "CallbackSet")

more_public = Filter(function(x) !is.null(x), more_public)
parent_env_shim = new.env(parent = parent_env)
parent_env_shim$inherit = inherit
R6::R6Class(classname = classname, inherit = inherit, public = c(public, more_public),
private = private, active = active, parent_env = parent_env_shim, lock_objects = FALSE)
private = private, active = active, parent_env = parent_env_shim, lock_objects = lock_objects,
cloneable = cloneable
)
}
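
The hunk that builds `invalid_stages` is collapsed in this view, so the following is only a base-R sketch of what such a check might look like, combined with the new "no arguments" requirement (`nargs = 0`). The helper `check_stage_methods()` is hypothetical and not part of mlr3torch; the stage names are taken from the `@param` list of `callback_set()`.

```r
# Stage names as documented in the parameters of callback_set().
valid_stages = c(
  "on_begin", "on_end", "on_epoch_begin", "on_before_valid", "on_epoch_end",
  "on_batch_begin", "on_batch_end", "on_after_backward",
  "on_batch_valid_begin", "on_batch_valid_end"
)

# Hypothetical helper: returns a character vector describing the problems found
# in a list of public methods (empty if everything looks fine).
check_stage_methods = function(public) {
  problems = character(0)
  # Anything that looks like a stage method but is not a known stage is
  # most likely a typo.
  stage_like = grep("^on_", names(public), value = TRUE)
  invalid = setdiff(stage_like, valid_stages)
  if (length(invalid) > 0) {
    problems = c(problems, sprintf("invalid stage: '%s'", invalid))
  }
  # Known stages must be functions without formal arguments,
  # because the context is read from self$ctx.
  for (name in intersect(stage_like, valid_stages)) {
    f = public[[name]]
    if (!is.function(f) || length(formals(f)) != 0) {
      problems = c(problems, sprintf("'%s' must be a function without arguments", name))
    }
  }
  problems
}

good = list(on_epoch_end = function() NULL, helper = function(x) x)
bad = list(on_epoch_ends = function() NULL, on_begin = function(ctx) NULL)

problems_good = check_stage_methods(good)
problems_bad = check_stage_methods(bad)
```

Here `good` passes because `helper` does not start with `on_` and is therefore not treated as a stage, while `bad` is flagged twice: once for the misspelled `on_epoch_ends` and once for the old-style `on_begin(ctx)` signature.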
37 changes: 18 additions & 19 deletions R/CallbackTorchCheckpoint.R → R/CallbackSetCheckpoint.R
@@ -1,23 +1,24 @@
#' @title Callback Torch Checkpoint
#' @title Checkpoint Callback
#'
#' @name mlr_callbacks_torch.checkpoint
#' @name mlr_callback_set.checkpoint
#'
#' @description
#' Saves the model during training.
#' @param path (`character(1)`)\cr
#' The path to a folder where the models are saved. This path must not exist before.
#' @param freq (`integer(1)`)\cr
#' The frequency how often the model is saved (epoch frequency).
#'
#' @family Callback
#' @export
CallbackTorchCheckpoint = R6Class("CallbackTorchCheckpoint",
inherit = CallbackTorch,
#' @include CallbackSet.R
CallbackSetCheckpoint = R6Class("CallbackSetCheckpoint",
inherit = CallbackSet,
lock_objects = FALSE,
# TODO: This should also save the learner itself
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' @param path (`character(1)`)\cr
#' The path to a folder where the models are saved. This path must not already exist.
#' @param freq (`integer(1)`)\cr
#' How often the model is saved, measured in epochs.
initialize = function(path, freq) {
# TODO: Maybe we want to be able to give gradient steps here instead of epochs?
assert_path_for_output(path)
@@ -27,34 +28,32 @@ CallbackTorchCheckpoint = R6Class("CallbackTorchCheckpoint",
},
#' @description
#' Saves the network state dict.
#' @param ctx [ContextTorch]
on_epoch_end = function(ctx) {
if ((ctx$epoch %% self$freq) == 0) {
torch::torch_save(ctx$network, file.path(self$path, paste0("network", ctx$epoch, ".pt")))
on_epoch_end = function() {
if ((self$ctx$epoch %% self$freq) == 0) {
torch::torch_save(self$ctx$network, file.path(self$path, paste0("network", self$ctx$epoch, ".pt")))
}
},
#' @description
#' Saves the final network.
#' @param ctx [ContextTorch]
on_end = function(ctx) {
path = file.path(self$path, paste0("network", ctx$epoch, ".pt"))
on_end = function() {
path = file.path(self$path, paste0("network", self$ctx$epoch, ".pt"))
if (!file.exists(path)) { # no need to save the last network twice if it was already saved.
torch::torch_save(ctx$network, path)
torch::torch_save(self$ctx$network, path)
}
}
)
)

#' @include TorchCallback.R CallbackTorch.R
#' @include TorchCallback.R
mlr3torch_callbacks$add("checkpoint", function() {
TorchCallback$new(
callback_generator = CallbackTorchCheckpoint,
callback_generator = CallbackSetCheckpoint,
param_set = ps(
path = p_uty(),
freq = p_int(lower = 1L)
),
id = "checkpoint",
label = "Checkpoint",
man = "mlr3torch::mlr_callbacks_torch.checkpoint"
man = "mlr3torch::mlr_callback_set.checkpoint"
)
})
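
The saving schedule of the checkpoint callback (every `freq` epochs, plus the final epoch unless it was already written) can be traced with a base-R sketch. This is a simplified illustration, not the callback itself: `saveRDS()` stands in for `torch::torch_save()`, a list stands in for the network, and the directory name and loop are made up for the example.

```r
# Fresh output folder inside the session's temporary directory.
checkpoint_dir = file.path(tempdir(), "checkpoint-sketch")
unlink(checkpoint_dir, recursive = TRUE)
dir.create(checkpoint_dir)

freq = 2L
n_epochs = 5L

for (epoch in seq_len(n_epochs)) {
  state = list(epoch = epoch) # placeholder for the network
  # Corresponds to on_epoch_end(): save every `freq` epochs.
  if ((epoch %% freq) == 0) {
    saveRDS(state, file.path(checkpoint_dir, paste0("network", epoch, ".pt")))
  }
}

# Corresponds to on_end(): save the final state unless that epoch was already saved.
final_path = file.path(checkpoint_dir, paste0("network", n_epochs, ".pt"))
if (!file.exists(final_path)) {
  saveRDS(list(epoch = n_epochs), final_path)
}

saved = sort(list.files(checkpoint_dir))
```

With `freq = 2` and five epochs, files are written for epochs 2 and 4 during training and for epoch 5 at the end; with `n_epochs = 4` the final save would be skipped because `network4.pt` already exists.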