Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace subgraph #560

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 118 additions & 1 deletion R/Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#' @section Fields:
#' * `pipeops` :: named `list` of [`PipeOp`] \cr
#' Contains all [`PipeOp`]s in the [`Graph`], named by the [`PipeOp`]'s `$id`s.
#' * `edges` :: [`data.table`] with columns `src_id` (`character`), `src_channel` (`character`), `dst_id` (`character`), `dst_channel` (`character`)\cr
#' * `edges` :: [`data.table`] with columns `src_id` (`character`), `src_channel` (`character`), `dst_id` (`character`), `dst_channel` (`character`)\cr
#' Table of connections between the [`PipeOp`]s. A [`data.table`]. `src_id` and `dst_id` are `$id`s of [`PipeOp`]s that must be present in
#' the `$pipeops` list. `src_channel` and `dst_channel` must respectively be `$output` and `$input` channel names of the
#' respective [`PipeOp`]s.
Expand Down Expand Up @@ -72,6 +72,10 @@
#' be supplied, e.g. a [`Learner`][mlr3::Learner] or a [`Filter`][mlr3filters::Filter]; see [`as_pipeop()`].
#' The argument given as `op` is always cloned; to access a `Graph`'s [`PipeOp`]s by-reference, use `$pipeops`.\cr
#' Note that `$add_pipeop()` is a relatively low-level operation, it is recommended to build graphs using [`%>>%`].
#' * `remove_pipeop(id)` \cr
#' (`character(1)`) -> `self` \cr
#' Mutates [`Graph`] by removing the [`PipeOp`] with the given `id`.
#' All edges that start or end at this [`PipeOp`] are removed as well, as is its entry in the [`Graph`]'s [`ParamSet`][paradox::ParamSet].
#' * `add_edge(src_id, dst_id, src_channel = NULL, dst_channel = NULL)` \cr
#' (`character(1)`, `character(1)`,
#' `character(1)` | `numeric(1)` | `NULL`,
Expand All @@ -81,6 +85,10 @@
#' channel `dst_channel` (identified by its name or number as listed in the [`PipeOp`]'s `$input`).
#' If source or destination [`PipeOp`] have only one input / output channel and `src_channel` / `dst_channel`
#' are therefore unambiguous, they can be omitted (i.e. left as `NULL`).
#' * `replace_subgraph(ids, substitute)` \cr
#' (`character()`, [`Graph`] | [`PipeOp`] | [`Learner`][mlr3::Learner] | [`Filter`][mlr3filters::Filter] | `...`) -> `self` \cr
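#' Mutates [`Graph`] by replacing the subgraph made up of the [`PipeOp`]s named in `ids` with the supplied `substitute`,
#' which is first converted to a [`Graph`] and cloned. If the replacement fails, the [`Graph`] is reset to its previous state.
#' Note that the supplied ids are always reordered in topological order with respect to the [`Graph`].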
#' * `plot(html)` \cr
#' (`logical(1)`) -> `NULL` \cr
#' Plot the [`Graph`], using either the \pkg{igraph} package (for `html = FALSE`, default) or
Expand Down Expand Up @@ -248,6 +256,115 @@ Graph = R6Class("Graph",
invisible(self)
},

# Remove a single PipeOp from the Graph, in place.
#
# Arguments:
#   id  (`character(1)`) id of a PipeOp that is currently part of the Graph.
# Returns:
#   `self`, invisibly, so that calls can be chained.
# Errors:
#   if `id` is not exactly one id of a PipeOp in the Graph.
remove_pipeop = function(id) {
  # `id` has to be ONE existing id: a longer vector would be used for recursive
  # `[[`-indexing below and would be recycled in the edge filter. Membership does
  # not depend on the order of the ids, so no topological sort is requested.
  assert_choice(id, self$ids())
  self$pipeops[[id]] = NULL
  # drop every edge that starts or ends at the removed PipeOp
  self$edges = self$edges[src_id != id & dst_id != id]

  if (!is.null(private$.param_set)) {
    # param_set is built on-demand; if it has not been requested before, its value may be NULL
    # and we don't need to remove anything.
    private$.param_set$remove_sets(id)
  }
  invisible(self)
},

# Replace the PipeOps named in `ids` by `substitute`, in place.
#
# The PipeOps in `ids` are removed, the PipeOps and edges of `substitute` are added, and the
# channels that became unconnected by the removal are wired to the free channels of
# `substitute` (row by row, after checking that the channel types are compatible).
#
# Arguments:
#   ids         (`character()`) ids of PipeOps currently in the Graph; reordered topologically,
#               duplicates are dropped.
#   substitute  anything `as_graph()` accepts; it is always cloned.
# Returns:
#   `self`, invisibly.
# Errors:
#   if `ids` is empty or contains unknown ids, if `substitute` cannot be converted, if the
#   number or the types of the channels to be connected do not match, or if the result is not
#   a valid Graph. In every error case pipeops, edges and param_set are reset.
replace_subgraph = function(ids, substitute) {
  # if this fails, pipeops, edges and param_set get reset
  old_pipeops = self$pipeops
  old_edges = self$edges
  old_ps = private$.param_set
  mutated = FALSE
  succeeded = FALSE
  on.exit({
    if (!succeeded) {
      self$pipeops = old_pipeops
      self$edges = old_edges
      # `old_ps` is the very object that remove_pipeop() changes in place, so once the Graph
      # has been touched it no longer describes the old state. Dropping it makes the
      # param_set get rebuilt on demand from the restored pipeops.
      # NOTE(review): assumes the on-demand rebuild recovers everything the old collection
      # held -- confirm against the `param_set` active binding.
      private$.param_set = if (mutated) NULL else old_ps
    }
  }, add = TRUE)

  sorted_ids = self$ids(TRUE)
  assert_subset(ids, choices = sorted_ids, empty.ok = FALSE)
  ids = intersect(sorted_ids, ids) # always reorder ids topologically (and drop duplicates)
  substitute = as_graph(substitute, clone = TRUE)

  # FIXME: check that ids are actually a valid subgraph of graph

  # FIXME:
  # check whether the input of the substitute is a vararg channel
  #if (any(strip_multiplicity_type(substitute$input$channel.name) == "...")) {
  #  stopf("Using a substitute with a vararg input channel is not supported (yet).")
  #}

  # check whether the last id that is to be replaced connects to a vararg channel
  #if (nrow(self$edges)) { # this can be a data table with zero rows
  #  type = self$edges[src_id == range(ids)[2L], dst_channel]
  #  if (length(type)) { # can be of length 0 if this is the end of the graph
  #    if (strip_multiplicity_type(type) == "...") {
  #      stopf("Replacing a Subgraph that is connected to a vararg channel is not supported (yet).")
  #    }
  #  }
  #}

  # Check that the output channels `src` can feed the input channels `dst` row by row and
  # return the corresponding edges. Both are channel tables as given by `$input` / `$output`.
  # FIXME: this reuses a lot of `%>>%`, we could write a general helper
  connecting_edges = function(src, dst) {
    if (nrow(src) != nrow(dst)) {
      stopf("Cannot connect %s output channel(s) to %s input channel(s) when replacing a subgraph: the numbers must match.",
        nrow(src), nrow(dst))
    }
    for (row in seq_len(nrow(dst))) {
      if (!are_types_compatible(strip_multiplicity_type(src$train[row]), strip_multiplicity_type(dst$train[row]))) {
        stopf("Output type of PipeOp %s during training (%s) incompatible with input type of PipeOp %s (%s)",
          src$op.id[row], src$train[row], dst$op.id[row], dst$train[row])
      }
      if (!are_types_compatible(strip_multiplicity_type(src$predict[row]), strip_multiplicity_type(dst$predict[row]))) {
        stopf("Output type of PipeOp %s during prediction (%s) incompatible with input type of PipeOp %s (%s)",
          src$op.id[row], src$predict[row], dst$op.id[row], dst$predict[row])
      }
    }
    cbind(
      src[, list(src_id = get("op.id"), src_channel = get("channel.name"))],
      dst[, list(dst_id = get("op.id"), dst_channel = get("channel.name"))]
    )
  }

  input_orig = self$input
  output_orig = self$output

  mutated = TRUE
  for (id in ids) {
    self$remove_pipeop(id) # also handles param_set
  }

  # channels that are free now but were not free before are the ones cut open by the removal;
  # this is a set-membership test, an elementwise `!=` would recycle the shorter vector
  input = self$input[!(name %in% input_orig$name)]
  output = self$output[!(name %in% output_orig$name)]

  for (pipeop in substitute$pipeops) {
    self$add_pipeop(pipeop) # also handles param_set
  }
  if (nrow(substitute$edges)) {
    self$edges = rbind(self$edges, substitute$edges)
  }

  # build edges from free output channels of substitute and free input channels of self
  if (nrow(input)) {
    self$edges = rbind(self$edges, connecting_edges(substitute$output, input))
  }

  # build edges from free output channels of self and free input channels of substitute
  if (nrow(output)) {
    self$edges = rbind(self$edges, connecting_edges(output, substitute$input))
  }

  # check if valid DAG
  tryCatch(self$ids(TRUE), error = function(error_condition) {
    stopf("Failed to infer new Graph structure. Resetting.")
  })

  succeeded = TRUE
  invisible(self)
},

plot = function(html = FALSE) {
assert_flag(html)
if (!length(self$pipeops)) {
Expand Down
2 changes: 1 addition & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ register_mlr3 = function() {
} # nocov end

# static code checks should not complain about commonly used data.table columns:
# these column names are used unquoted inside data.table expressions, so they are
# declared here to keep them from being reported as undefined global variables.
utils::globalVariables(c("src_id", "dst_id", "name", "op.id", "response", "truth"))
utils::globalVariables(c("src_id", "dst_id", "src_channel", "dst_channel", "name", "op.id", "response", "truth"))

leanify_package()
10 changes: 9 additions & 1 deletion man/Graph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

80 changes: 78 additions & 2 deletions tests/testthat/test_Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ test_that("linear graph", {
g = Graph$new()
expect_equal(g$ids(sorted = TRUE), character(0))

# FIXME: we should "dummy" ops, so we can change properties of the ops at will
# FIXME: we should use "dummy" ops, so we can change properties of the ops at will
# we should NOT use PipeOpNOP, because we want to check that $train/$predict actually does something.
# FIXME: we should packages of the graph
# FIXME: we should check packages of the graph
op_ds = PipeOpSubsample$new()
op_pca = PipeOpPCA$new()
op_lrn = PipeOpLearner$new(mlr_learners$get("classif.rpart"))
Expand Down Expand Up @@ -434,3 +434,79 @@ test_that("dot output", {
"6 [label=\"OUTPUT",
"nop_output\",fontsize=24]"), out[-c(1L, 15L)])
})



test_that("replace_subgraph", {
  task = tsk("iris")

  # the graph must start at PipeOp `first_id` and end at PipeOp `last_id`
  expect_endpoints = function(graph, first_id, last_id) {
    expect_true(graph$input$op.id == first_id)
    expect_true(graph$output$op.id == last_id)
  }

  # ---- Basics: failing calls leave the graph untouched ----
  gr = Graph$new()$add_pipeop(PipeOpDebugMulti$new(2, 2))
  address_old = address(gr)
  gr_old = gr$clone(deep = TRUE)
  expect_error(
    gr$replace_subgraph("id_not_present", PipeOpDebugMulti$new(2, 2)),
    regexp = "Assertion on 'ids' failed"
  )
  expect_error(
    gr$replace_subgraph("debug.multi", NULL),
    regexp = "op can not be converted to PipeOp"
  )
  expect_equal(gr, gr_old) # error results in a clean reset
  expect_true(address_old == address(gr))
  expect_deep_clone(gr_old, gr)

  # replacing a PipeOp by an identical one happens in place and gives what a deep clone would give
  gr$replace_subgraph("debug.multi", substitute = PipeOpDebugMulti$new(2, 2))
  expect_equal(gr_old, gr)
  expect_true(address_old == address(gr))
  expect_deep_clone(gr_old, gr)

  # ---- Linear Graph: every case starts from a fresh copy of scale -> pca -> rpart ----
  gr_old = po("scale") %>>% po("pca") %>>% lrn("classif.rpart")
  linear_cases = list(
    # replace beginning
    list(replace = "scale", by = po("scalemaxabs"),
      ids = c("scalemaxabs", "pca", "classif.rpart")),
    # replace end
    list(replace = "classif.rpart", by = lrn("classif.featureless"),
      ids = c("scale", "pca", "classif.featureless")),
    # replace whole graph
    list(replace = c("scale", "pca", "classif.rpart"),
      by = po("scalemaxabs") %>>% po("ica") %>>% lrn("classif.featureless"),
      ids = c("scalemaxabs", "ica", "classif.featureless")),
    # replace linear subgraph (ids deliberately given out of order)
    list(replace = c("pca", "scale"), by = po("scalemaxabs") %>>% po("ica"),
      ids = c("scalemaxabs", "ica", "classif.rpart"))
  )
  for (case in linear_cases) {
    gr = gr_old$clone(deep = TRUE)
    gr$replace_subgraph(case$replace, substitute = case$by)
    expect_set_equal(gr$ids(), case$ids)
    expect_endpoints(gr, case$ids[[1L]], case$ids[[3L]])
    expect_null(gr$train(task)[[1L]])
    expect_prediction_classif(gr$predict(task)[[1L]])
  }

  # ---- Non linear Graph ----
  gr = po("scale") %>>%
    po("branch", c("pca", "nop")) %>>%
    gunion(list(po("pca"), po("nop"))) %>>%
    po("unbranch") %>>%
    lrn("classif.rpart")
  gr_old = gr$clone(deep = TRUE)
  branch_ids = c("branch", "pca", "nop", "unbranch")
  #expect_error(gr$replace_subgraph(c("nop"), substitute = po("ica")), regexp = "connected to a vararg channel is not supported") # FIXME:
  expect_error(
    gr$replace_subgraph(branch_ids, substitute = lrn("classif.featureless")),
    regexp = "Output type of PipeOp classif.featureless during training"
  )
  new_branching = po("branch", c("pca", "ica")) %>>%
    gunion(list(po("pca"), po("ica"))) %>>%
    po("unbranch")
  gr$replace_subgraph(branch_ids, substitute = new_branching)
  expect_set_equal(gr$ids(TRUE), c("scale", "branch", "pca", "ica", "unbranch", "classif.rpart"))
  expect_endpoints(gr, "scale", "classif.rpart")

  # train once per branch; the branch that is not selected must end up with a NO_OP state
  expect_null(gr$train(task)[[1L]])
  state_default_branch = gr$state
  gr$param_set$values$branch.selection = "ica"
  expect_null(gr$train(task)[[1L]])
  state_ica_branch = gr$state
  expect_true(test_r6(state_default_branch$ica, classes = "NO_OP"))
  expect_true(test_r6(state_ica_branch$pca, classes = "NO_OP"))
  expect_prediction_classif(gr$predict(task)[[1L]])
})