Skip to content

Commit

Permalink
fix small issues when filtering predictions + add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bblodfon committed Nov 11, 2023
1 parent 54561e2 commit f0e6dfb
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 6 deletions.
8 changes: 7 additions & 1 deletion R/PredictionDataSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,13 @@ filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) {
distr = pdata$distr

if (testDistribution(distr)) { # distribution
pdata$distr = distr[keep]
ok = inherits(distr, c("VectorDistribution", "Matdist", "Arrdist")) &&
length(keep) > 1 # edge case: Arrdist(1xYxZ) and keep = FALSE
if (ok) {
pdata$distr = distr[keep] # we can subset row/samples like this
} else {
pdata$distr = base::switch(keep, distr) # one distribution only
}
} else {
if (length(dim(distr)) == 2) { # 2d matrix
pdata$distr = distr[keep, , drop = FALSE]
Expand Down
4 changes: 3 additions & 1 deletion R/PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,12 @@ PredictionSurv = R6Class("PredictionSurv",
}
},
.distrify_survarray = function(x) {
if (inherits(x, "array")) { # can be matrix as well
if (inherits(x, "array") && nrow(x) > 0) { # can be matrix as well
# create Matdist or Arrdist (default => median curve)
distr6::as.Distribution(1 - x, fun = "cdf",
decorators = c("CoreStatistics", "ExoticStatistics"))
} else {
NULL
}
}
)
Expand Down
4 changes: 2 additions & 2 deletions inst/testthat/helper_expectations.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ expect_prediction_surv = function(p) {
"response", "distr", "lp", "crank"))
checkmate::expect_data_table(data.table::as.data.table(p), nrows = length(p$row_ids))
checkmate::expect_atomic_vector(p$missing)
if ("distr" %in% p$predict_types) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist"))
if ("distr" %in% p$predict_types && !is.null(p$distr)) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist", "WeightedDiscrete"))
}
expect_true(inherits(p, "PredictionSurv"))
}
34 changes: 32 additions & 2 deletions tests/testthat/test_PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ test_that("as_prediction_surv", {
})

test_that("filtering", {
p = suppressWarnings(lrn("surv.coxph")$train(task)$predict(task))
p2 = reshape_distr_to_3d(p) # survival array distr
p = suppressWarnings(lrn("surv.coxph")$train(task)$predict(task)) # survival matrix
p2 = reshape_distr_to_3d(p) # survival array
p3 = p$clone()
p4 = p2$clone()
p3$data$distr = p3$distr # Matdist
Expand Down Expand Up @@ -209,4 +209,34 @@ test_that("filtering", {
expect_equal(nrow(p2$data$distr), 3)
expect_true(inherits(p3$data$distr, "Matdist"))
expect_true(inherits(p4$data$distr, "Arrdist"))

# edge case: filter to 1 observation
p$filter(20)
p2$filter(20)
p3$filter(20)
p4$filter(20)
expect_prediction_surv(p)
expect_prediction_surv(p2)
expect_prediction_surv(p3)
expect_prediction_surv(p4)
expect_matrix(p$data$distr, nrows = 1)
expect_array(p2$data$distr, d = 3)
expect_equal(nrow(p2$data$distr), 1)
expect_true(inherits(p3$data$distr, "WeightedDiscrete")) # from Matdist!
expect_true(inherits(p4$data$distr, "Arrdist")) # remains an Arrdist!

# filter to 0 observations using non-existent (positive) id
p$filter(42)
p2$filter(42)
p3$filter(42)
p4$filter(42)

expect_prediction_surv(p)
expect_prediction_surv(p2)
expect_prediction_surv(p3)
expect_prediction_surv(p4)
expect_null(p$distr)
expect_null(p2$distr)
expect_null(p3$distr)
expect_null(p4$distr)
})

0 comments on commit f0e6dfb

Please sign in to comment.