Skip to content

Commit

Permalink
test_dictionary with new paramset
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Jan 14, 2024
1 parent 42db529 commit 5051139
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions tests/testthat/test_dictionary.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,31 +117,37 @@ test_that("Dictionary contains all PipeOps", {
expect_equal(other_obj$phash, test_obj$phash, info = paste(dictname, "$new id test 2"))
expect_equal(inflate(do.call(pogen$new, args)), test_obj, info = dictname)


tops = test_obj$param_set
# we now check if hyperparameters can be changed through construction
# we do this by automatically generating a hyperparameter value that deviates from the automatically constructed one.
# However, for ParamUty we can't do that, so if there are only 'ParamUty' parameter we skip this part.
eligibleparams = test_obj$param_set$params[test_obj$param_set$class != "ParamUty"]
eligibleparams = discard(eligibleparams, function(p) {
# filter out discrete params with only one level, or the numeric parameters with $lower == $upper
# The use '&&' here is intentional, because numeric parameters have 0 levels, and discrete parameters have $lower == $upper (== NA)
length(p$levels) < 2 && isTRUE(all.equal(p$lower, p$upper))
})
eligibleparams = which(
tops$class != "ParamUty" &
# filter out discrete params with only one level, or the numeric parameters with $lower == $upper
# Note that numeric parameters have 0 levels, and discrete parameters have $lower == $upper (== NA)
(
(!is.na(tops$lower) & tops$lower != tops$upper) |
(is.finite(tops$nlevels) & tops$nlevels > 1)
)
)
if (length(eligibleparams)) {
testingparam = eligibleparams[[1]]
testingparam = tops$ids()[[eligibleparams[[1]]]]

# we want to construct an object where the parameter value is *different* from the value it gets on construction by default.
# For this we take a few candidate values and `setdiff` the original value
origval = as.atomic(test_obj$param_set$values[[testingparam$id]])
if (testingparam$class %in% c("ParamLgl", "ParamFct")) {
candidates = testingparam$levels
origval = as.atomic(test_obj$param_set$values[[testingparam]])
if (tops$class[[testingparam]] %in% c("ParamLgl", "ParamFct")) {
candidates = tops$levels[[testingparam]]
} else {
candidates = Filter(function(x) is.finite(x) && !is.na(x), c(testingparam$lower, testingparam$upper, testingparam$lower + 1, 0, origval + 1))
candidates = Filter(function(x) is.finite(x) && !is.na(x),
c(tops$lower[[testingparam]], tops$upper[[testingparam]], tops$lower[[testingparam]] + 1, 0, origval + 1))
}
val = setdiff(candidates, origval)[1]

# construct the `param_vals = list(PARNAME = PARVAL)` construction argument
args$param_vals = list(val)
names(args$param_vals) = testingparam$id
names(args$param_vals) = testingparam

# check that the constructed object is different from the test_obj, but setting the test_obj's parameter
# makes them equal again.
Expand All @@ -152,7 +158,7 @@ test_that("Dictionary contains all PipeOps", {
# phash should be independent of this!
expect_true(isTRUE(all.equal(dict_constructed$phash, test_obj$phash)), dictname)

test_obj$param_set$values[[testingparam$id]] = val
test_obj$param_set$values[[testingparam]] = val
expect_equal(touch(dict_constructed), test_obj)
expect_equal(inflate(touch(gen_constructed)), test_obj)

Expand Down

0 comments on commit 5051139

Please sign in to comment.