Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Caching Prototype #382

Open
wants to merge 29 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Imports:
mlr3misc (>= 0.1.4),
paradox,
R6,
R.cache,
withr
Suggests:
ggplot2,
Expand Down
88 changes: 86 additions & 2 deletions R/Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@
#' (and therefore their `$param_set$values`) and a hash of `$edges`.
#' * `keep_results` :: `logical(1)` \cr
#' Whether to store intermediate results in the [`PipeOp`]'s `$.result` slot, mostly for debugging purposes. Default `FALSE`.
#' * `cache` :: `logical(1)` \cr
#' Whether to cache individual [`PipeOp`]'s during "train" and "predict". Default `FALSE`.
#' Caching is performed using the [`R.cache`][R.cache::R.cache] package.
#' Caching can be disabled/enabled globally using `getOption("R.cache.enabled", TRUE)`.
#' By default, files are cached in `R.cache::getCacheRootPath()`.
#' For more information on how to set the cache path or retrieve cached items, please consult
#' the [`R.cache`][R.cache::R.cache] documentation.
#' Caching can be fine-controlled for each [`PipeOp`] by adjusting individual [`PipeOp`]'s
#' `cache`, `cache_state` and `stochastic` fields.
#'
#' @section Methods:
#' * `ids(sorted = FALSE)` \cr
Expand Down Expand Up @@ -407,6 +416,13 @@ Graph = R6Class("Graph",
} else {
map(self$pipeops, "state")
}
},
cache = function(val) {
if (!missing(val)) {
private$.cache = assert_flag(val)
} else {
private$.cache
}
}
),

Expand All @@ -419,7 +435,8 @@ Graph = R6Class("Graph",
value
)
},
.param_set = NULL
.param_set = NULL,
.cache = FALSE
)
)

Expand Down Expand Up @@ -539,7 +556,7 @@ graph_reduce = function(self, input, fun, single_input) {
input = input_tbl$payload
names(input) = input_tbl$name

output = op[[fun]](input)
output = cached_pipeop_eval(self, op, fun, input)
if (self$keep_results) {
op$.result = output
}
Expand Down Expand Up @@ -609,3 +626,70 @@ predict.Graph = function(object, newdata, ...) {
}
result
}

# Cached train/predict of a PipeOp.
# 1) Caching of a PipeOp only performed if graph and po have `cache = TRUE`,
# i.e both the Graph AND the PipeOp want to be cached.
# 2) Additionally, caching is only performed if 'train' or 'predict' is not stochastic
# for a given PipeOp. This can be obtained from `.$stochastic` and can be set
# for each PipeOp.
# 3) During training we have two options
# Each PipeOp stores whether it wants to do I. or II. in `.$cache_state`.
# I. Cache only state:
# This is possible if the train transform is the same as the predict transform
# and predict is comparatively cheap (i.e. filters).
# II. Cache state and output
# (All other cases)

cached_pipeop_eval = function(self, op, fun, input) {
  # Evaluate `op[[fun]](input)`, memoizing results via R.cache when both the
  # Graph (`self$cache`) and the PipeOp (`op$cache`) opt in AND the PipeOp is
  # deterministic during the given phase (`fun` not listed in `op$stochastic`).
  # `fun` is either "train" or "predict"; returns the PipeOp's output list.
  if (self$cache && op$cache && fun %nin% op$stochastic) {
    require_namespaces("R.cache")
    # Key on the hashes of all inputs plus the PipeOp's own hash
    # (class, id, param values).
    cache_key = list(map_chr(input, get_hash), op$hash)
    if (fun == "train") {
      if (op$cache_state) {
        # I. Cache the state only: training applies the same transform as
        # prediction, so training outputs can be recreated cheaply below.
        # Note: R.cache restores variables assigned inside the memoized
        # expression (here `state`) on a cache hit.
        R.cache::evalWithMemoization({
          op[[fun]](input)
          state = op$state
        }, key = cache_key)
        # On a cache hit "train" was not called, so restore the state.
        if (is.null(op$state)) op$state = state
        # Recreate the training output by predicting on the training input;
        # this avoids persisting (possibly large) outputs to disk.
        return(cached_pipeop_eval(self, op, "predict", input))
      } else {
        # II. Cache both state and output.
        R.cache::evalWithMemoization({
          result = list(output = op[[fun]](input), state = op$state)
        }, key = cache_key)
        # On a cache hit "train" was not called, so restore the state.
        if (is.null(op$state)) op$state = result$state
        return(result$output)
      }
    } else if (fun == "predict" && !op$cache_state) {
      # During predict, only cache PipeOps whose outputs are not already
      # recreated from a cached state (i.e. `cache_state` is FALSE).
      # The prediction depends on the trained state, so include it in the
      # key to avoid stale hits when the op was trained on different data.
      cache_key = c(cache_key, list(get_hash(op$state)))
      R.cache::evalWithMemoization({
        output = op[[fun]](input)
      }, key = cache_key)
      return(output)
    }
  }
  # Fallback: caching does not apply, evaluate directly.
  return(op[[fun]](input))
}

get_hash = function(x) {
  # Prefer the object's own `$hash` field; objects without one (or whose
  # `$hash` access errors) are digested directly as a fallback.
  own_hash = tryCatch(x$hash, error = function(e) NULL)
  if (is.null(own_hash)) {
    own_hash = digest(x, algo = "xxhash64")
  }
  own_hash
}
41 changes: 39 additions & 2 deletions R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,21 @@
#' If the [`Graph`]'s `$keep_results` flag is set to `TRUE`, then the intermediate Results of `$train()` and `$predict()`
#' are saved to this slot, exactly as they are returned by these functions. This is mainly for debugging purposes
#' and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`.
#' * `cache` :: `logical(1)` \cr
#' Whether to cache the [`PipeOp`]'s state and/or output during "train" and "predict". Defaults to `TRUE`.
#' See the `cache` field in [`Graph`] for more detailed information on caching, as well as `cache_state` and
#' `stochastic` below.
#' * `cache_state` :: `logical(1)` \cr
#' Whether the [`PipeOp`]'s behaviour during training is equal to its behaviour during prediction
#' (other than setting a state). In this case, only the [`PipeOp`]'s state is cached.
#' This avoids caching possibly large intermediate results.
#' Defaults to `TRUE`.
#' * `stochastic` :: `character` \cr
#' Whether a [`PipeOp`] is stochastic during `"train"`, `"predict"`, or not at all: `character(0)`.
#' Defaults to `character(0)` (deterministic). Stochastic [`PipeOp`]s are not cached during the
#' respective phase.
#' A [`PipeOp`] is only cached if it is deterministic.
#'
#'
#' @section Methods:
#' * `train(input)`\cr
Expand Down Expand Up @@ -254,7 +269,6 @@ PipeOp = R6Class("PipeOp",
if (is_noop(self$state)) {
stopf("Pipeop %s got NO_OP during train but no NO_OP during predict.", self$id)
}

input = check_types(self, input, "input", "predict")
output = private$.predict(input)
output = check_types(self, output, "output", "predict")
Expand Down Expand Up @@ -296,6 +310,26 @@ PipeOp = R6Class("PipeOp",
hash = function() {
digest(list(class(self), self$id, self$param_set$values),
algo = "xxhash64")
},
cache = function(val) {
if (!missing(val)) {
private$.cache = assert_flag(val)
} else {
private$.cache
}
},
cache_state = function(val) {
if (!missing(val)) {
stop("cache_state is read-only!")
}
private$.cache_state
},
stochastic = function(val) {
if (!missing(val)) {
private$.stochastic = assert_subset(val, c("train", "predict"))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be read-only and set during initialization?

} else {
private$.stochastic
}
}
),

Expand All @@ -318,7 +352,10 @@ PipeOp = R6Class("PipeOp",
.predict = function(input) stop("abstract"),
.param_set = NULL,
.param_set_source = NULL,
.id = NULL
.id = NULL,
.cache = TRUE,
.cache_state = TRUE,
.stochastic = character(0)
)
)

Expand Down
3 changes: 2 additions & 1 deletion R/PipeOpBranch.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ PipeOpBranch = R6Class("PipeOpBranch",
ret = named_list(self$output$name, NO_OP)
ret[[self$param_set$values$selection]] = inputs[[1]]
ret
}
},
.cache = FALSE
)
)

Expand Down
14 changes: 13 additions & 1 deletion R/PipeOpChunk.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ PipeOpChunk = R6Class("PipeOpChunk",
)
}
),
active = list(
stochastic = function(val) {
if (!missing(val)) {
assert_subset(val, c("train", "predict"))
private$.stochastic = val
} else {
if (self$param_set$values$shuffle) return("train")
character(0)
}
}
),
private = list(
.train = function(inputs) {
self$state = list()
Expand All @@ -88,7 +99,8 @@ PipeOpChunk = R6Class("PipeOpChunk",
},
.predict = function(inputs) {
rep(inputs, self$outnum)
}
},
.cache = FALSE
)
)

Expand Down
5 changes: 4 additions & 1 deletion R/PipeOpClassBalancing.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ PipeOpClassBalancing = R6Class("PipeOpClassBalancing",
task_filter_ex(task, new_ids)
},

.predict_task = identity
.predict_task = identity,
.cache = FALSE,
.stochastic = "train",
.cache_state = FALSE
)
)

Expand Down
3 changes: 2 additions & 1 deletion R/PipeOpCopy.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ PipeOpCopy = R6Class("PipeOpCopy",
},
.predict = function(inputs) {
rep_len(inputs, self$outnum)
}
},
.cache = FALSE
)
)

Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpImputeHist.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ PipeOpImputeHist = R6Class("PipeOpImputeHist",
}
feature[is.na(feature)] = sampled
feature
}
},
.cache = FALSE,
.stochastic = c("train", "predict")
)
)

Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpImputeSample.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ PipeOpImputeSample = R6Class("PipeOpImputeSample",
feature[is.na(feature)] = sample(model, outlen, replace = TRUE)
}
feature
}
},
.cache = FALSE,
.stochastic = c("train", "predict")
)
)

Expand Down
3 changes: 2 additions & 1 deletion R/PipeOpNOP.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ PipeOpNOP = R6Class("PipeOpNOP",

.predict = function(inputs) {
inputs
}
},
.cache = FALSE
)
)

Expand Down
29 changes: 29 additions & 0 deletions R/PipeOpProxy.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,35 @@ PipeOpProxy = R6Class("PipeOpProxy",
)
}
),
active = list(
cache = function(val) {
if (!missing(val)) {
self$param_set$values$content$cache = assert_flag(val)
} else {
self$param_set$values$content$cache
}
},
stochastic = function(val) {
if (!missing(val)) {
assert_subset(val, c("train", "predict"))
if (inherits(self$param_set$values$content, "Graph"))
stop("'stochastic' not be set when content is a graph!")
else
self$param_set$values$content$stochastic = val
} else {
if (inherits(self$param_set$values$content, "Graph")) return(character(0))
self$param_set$values$content$stochastic
}
},
cache_state = function(val) {
if (!missing(val)) {
stop("cache_state is read-only!")
} else {
if (inherits(self$param_set$values$content, "Graph")) return(TRUE)
self$param_set$values$content$cache_state
}
}
),
private = list(
.param_set = NULL,
.param_set_source = NULL,
Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpSmote.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ PipeOpSmote = R6Class("PipeOpSmote",
}
setnames(st, "class", task$target_names)
task$rbind(st)
}
},
.cache = FALSE,
.stochastic = "train"
)
)

Expand Down
4 changes: 2 additions & 2 deletions R/PipeOpSubsample.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ PipeOpSubsample = R6Class("PipeOpSubsample",
self$state = list()
task_filter_ex(task, keep)
},

.predict_task = identity
.predict_task = identity,
.cache_state = FALSE
)
)

Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpThreshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ PipeOpThreshold = R6Class("PipeOpThreshold",
}

list(prd$set_threshold(thr))
}
},
.cache = FALSE,
.cache_state = FALSE
)
)

Expand Down
3 changes: 2 additions & 1 deletion R/PipeOpUnbranch.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ PipeOpUnbranch = R6Class("PipeOpUnbranch",
},
.predict = function(inputs) {
filter_noop(inputs)
}
},
.cache = FALSE
)
)

Expand Down
Loading