From 50511397c1c19ca474f6398b379827a36add8adb Mon Sep 17 00:00:00 2001 From: mb706 Date: Sun, 14 Jan 2024 13:09:26 +0100 Subject: [PATCH] test_dictionary with new paramset --- tests/testthat/test_dictionary.R | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/tests/testthat/test_dictionary.R b/tests/testthat/test_dictionary.R index e3ae20d97..1a598668c 100644 --- a/tests/testthat/test_dictionary.R +++ b/tests/testthat/test_dictionary.R @@ -117,31 +117,37 @@ test_that("Dictionary contains all PipeOps", { expect_equal(other_obj$phash, test_obj$phash, info = paste(dictname, "$new id test 2")) expect_equal(inflate(do.call(pogen$new, args)), test_obj, info = dictname) + + tops = test_obj$param_set # we now check if hyperparameters can be changed through construction # we do this by automatically generating a hyperparameter value that deviates from the automatically constructed one. # However, for ParamUty we can't do that, so if there are only 'ParamUty' parameter we skip this part. - eligibleparams = test_obj$param_set$params[test_obj$param_set$class != "ParamUty"] - eligibleparams = discard(eligibleparams, function(p) { - # filter out discrete params with only one level, or the numeric parameters with $lower == $upper - # The use '&&' here is intentional, because numeric parameters have 0 levels, and discrete parameters have $lower == $upper (== NA) - length(p$levels) < 2 && isTRUE(all.equal(p$lower, p$upper)) - }) + eligibleparams = which( + tops$class != "ParamUty" & + # filter out discrete params with only one level, or the numeric parameters with $lower == $upper + # Note that numeric parameters have 0 levels, and discrete parameters have $lower == $upper (== NA) + ( + (!is.na(tops$lower) & tops$lower != tops$upper) | + (is.finite(tops$nlevels) & tops$nlevels > 1) + ) + ) if (length(eligibleparams)) { - testingparam = eligibleparams[[1]] + testingparam = tops$ids()[[eligibleparams[[1]]]] # we want to construct an object where the parameter value is *different* from the value it gets on construction by default. # For this we take a few candidate values and `setdiff` the original value - origval = as.atomic(test_obj$param_set$values[[testingparam$id]]) - if (testingparam$class %in% c("ParamLgl", "ParamFct")) { - candidates = testingparam$levels + origval = as.atomic(test_obj$param_set$values[[testingparam]]) + if (tops$class[[testingparam]] %in% c("ParamLgl", "ParamFct")) { + candidates = tops$levels[[testingparam]] } else { - candidates = Filter(function(x) is.finite(x) && !is.na(x), c(testingparam$lower, testingparam$upper, testingparam$lower + 1, 0, origval + 1)) + candidates = Filter(function(x) is.finite(x) && !is.na(x), + c(tops$lower[[testingparam]], tops$upper[[testingparam]], tops$lower[[testingparam]] + 1, 0, origval + 1)) } val = setdiff(candidates, origval)[1] # construct the `param_vals = list(PARNAME = PARVAL)` construction argument args$param_vals = list(val) - names(args$param_vals) = testingparam$id + names(args$param_vals) = testingparam # check that the constructed object is different from the test_obj, but setting the test_obj's parameter # makes them equal again. @@ -152,7 +158,7 @@ test_that("Dictionary contains all PipeOps", { # phash should be independent of this! expect_true(isTRUE(all.equal(dict_constructed$phash, test_obj$phash)), dictname) - test_obj$param_set$values[[testingparam$id]] = val + test_obj$param_set$values[[testingparam]] = val expect_equal(touch(dict_constructed), test_obj) expect_equal(inflate(touch(gen_constructed)), test_obj)