Support arrdist #331

Merged Oct 16, 2023 (38 commits)
Changes from 21 commits
Commits
d5df42d
change input to weighted_survival_score() for Graf score + tests
bblodfon Sep 8, 2023
72b33ad
add support for distr6::Arrdist prediction type
bblodfon Sep 8, 2023
10e48aa
Merge pull request #330 from bblodfon/main
RaphaelS1 Sep 9, 2023
cb549c2
extend surv_return() to handle survival arrays + tests
bblodfon Sep 10, 2023
ed430ee
revert graf change
RaphaelS1 Sep 10, 2023
ff8cdbd
pass names when input is a vector and 'times' is not given
bblodfon Sep 11, 2023
6b74197
change the default curve to median
bblodfon Sep 11, 2023
865bf03
update test, add edge cases
bblodfon Sep 11, 2023
27e18d2
speeding up RCLL by refactoring a bit
bblodfon Sep 11, 2023
a65d691
fix bug: filtering PredictionSurv obj works with 3d survival arrays
bblodfon Sep 11, 2023
d1d8cfa
Merge branch 'main' into support_arrdist
RaphaelS1 Sep 15, 2023
ddcf3c1
Update DESCRIPTION
RaphaelS1 Sep 15, 2023
0f1a5dd
Update DESCRIPTION
RaphaelS1 Sep 15, 2023
43f244a
better example
bblodfon Sep 18, 2023
66282af
revert to original distrification
bblodfon Sep 18, 2023
6df6afe
better doc
bblodfon Sep 18, 2023
edd33f8
fix bug when input is a 3d survival array
bblodfon Sep 18, 2023
87d69b2
update example
bblodfon Sep 19, 2023
7b70fd8
fix bug in weighted_survival_score()
bblodfon Sep 19, 2023
a8306e4
code optimization
bblodfon Sep 25, 2023
86dd71c
code optimization
bblodfon Sep 25, 2023
293c4c1
test distr measures with 3d survival array
bblodfon Sep 25, 2023
44e1551
fix R CMD check warnings
bblodfon Sep 25, 2023
33bdc13
revert changes that tested graf score results numerically
bblodfon Sep 25, 2023
af3810c
bug fix + support combining survival array distrs
bblodfon Oct 4, 2023
d22a8ef
refactor helper function to 3d-ify a survival matrix
bblodfon Oct 4, 2023
3f59e53
add tests to combine 3d survival arrays
bblodfon Oct 4, 2023
dfc97f0
add 'which.curve' parameter to integrated scores and pecs + tests
bblodfon Oct 4, 2023
7237f70
move argument 'which.curve' into last position
bblodfon Oct 6, 2023
1fabbd7
Revert "add 'which.curve' parameter to integrated scores and pecs + t…
bblodfon Oct 6, 2023
085e87a
sapply => lapply
bblodfon Oct 6, 2023
c5bec7b
update distr6 dependency
RaphaelS1 Oct 8, 2023
0ae56d0
fix 2 small bugs
bblodfon Oct 13, 2023
707482e
update tests (combining different prediction types)
bblodfon Oct 13, 2023
6b01f57
combining survival matrices with arrays is now supported
bblodfon Oct 15, 2023
36e6b90
pump distr6
bblodfon Oct 15, 2023
762d807
correct distr6 PR
bblodfon Oct 15, 2023
a4b9d88
Update DESCRIPTION
RaphaelS1 Oct 16, 2023
6 changes: 3 additions & 3 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.5.2
Version: 0.5.3
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
@@ -43,7 +43,7 @@ Depends:
Imports:
checkmate,
data.table,
distr6 (>= 1.6.11),
distr6 (>= 1.8.0),
ggplot2,
mlr3misc (>= 0.7.0),
mlr3viz,
@@ -62,7 +62,7 @@ Suggests:
param6 (>= 0.2.4),
pracma,
rpart,
set6 (>= 0.1.7),
set6 (>= 0.2.6),
simsurv,
survAUC,
testthat,
2 changes: 1 addition & 1 deletion R/MeasureRegrLogloss.R
@@ -43,7 +43,7 @@ MeasureRegrLogloss = R6::R6Class("MeasureRegrLogloss",
distr = prediction$distr
truth = prediction$truth

if (inherits(distr, "Matdist")) {
if (inherits(distr, c("Matdist", "Arrdist"))) {
pdf = diag(distr$pdf(truth))
} else {
pdf = as.numeric(distr$pdf(data = matrix(truth, nrow = 1)))
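Reviewer note on the `diag()` call in this hunk: a minimal base-R sketch of why the diagonal is taken when every observation's density is evaluated at every truth value. The matrix `pmf`, the vector `truth`, and all values are made up for illustration, and the sketch does not call distr6; it only assumes the result is an n x n matrix with evaluation points on one axis and observations on the other.

```r
# Hypothetical discrete densities: rows = observations, columns = time points
times = c(1, 2, 3)
pmf = matrix(c(0.2, 0.5, 0.3,
               0.6, 0.3, 0.1,
               0.1, 0.1, 0.8),
  nrow = 3, byrow = TRUE,
  dimnames = list(NULL, as.character(times)))
truth = c(2, 1, 3)

# Evaluate every observation at every truth value:
# rows = evaluation points, columns = observations (n x n)
pdf_all = t(pmf[, as.character(truth)])

# Entry [i, i] pairs observation i with its own truth[i]
pdf_own = diag(pdf_all)

# Log loss per observation
logloss = -log(pdf_own)
```

Only the n diagonal entries are used, so the off-diagonal evaluations are discarded.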
2 changes: 1 addition & 1 deletion R/MeasureSurvIntLogloss.R
@@ -104,7 +104,7 @@ MeasureSurvIntLogloss = R6::R6Class("MeasureSurvIntLogloss",
}

score = weighted_survival_score("intslogloss", truth = prediction$truth,
distribution = prediction$distr, times = ps$times, t_max = ps$t_max,
distribution = prediction$data$distr, times = ps$times, t_max = ps$t_max,
p_max = ps$p_max, proper = ps$proper, train = train, eps = ps$eps)

if (ps$se) {
9 changes: 5 additions & 4 deletions R/MeasureSurvRCLL.R
@@ -73,16 +73,17 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL",
event = truth[, 2] == 1
event_times = truth[event, 1]
cens_times = truth[!event, 1]
distr = prediction$distr

if (!any(event)) { # all censored
# survival at outcome time (survived *at least* this long)
out[!event] = diag(as.matrix(prediction$distr[!event]$survival(cens_times)))
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
} else if (all(event)) { # all uncensored
# pdf at outcome time (survived *this* long)
out[event] = diag(as.matrix(prediction$distr[event]$pdf(event_times)))
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
} else { # mix
out[event] = diag(as.matrix(prediction$distr[event]$pdf(event_times)))
out[!event] = diag(as.matrix(prediction$distr[!event]$survival(cens_times)))
out[event] = diag(as.matrix(distr[event]$pdf(event_times)))
out[!event] = diag(as.matrix(distr[!event]$survival(cens_times)))
}

stopifnot(!any(out == -99L)) # safety check
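Reviewer note: the branching above can be illustrated with plain vectors. This is a sketch under assumptions, not the package code: `pdf_at_time` and `surv_at_time` are hypothetical precomputed values of each observation's density and survival function at its own outcome time, and the final aggregation as a mean negative log is assumed here rather than shown in this hunk.

```r
event = c(TRUE, FALSE, TRUE)         # TRUE = event observed, FALSE = censored
pdf_at_time = c(0.10, 0.05, 0.20)    # hypothetical f_i(t_i)
surv_at_time = c(0.60, 0.70, 0.40)   # hypothetical S_i(t_i)

# Sentinel value, mirroring the safety check in the hunk
out = rep(-99, length(event))

# Uncensored: density at the event time (survived *this* long)
out[event] = pdf_at_time[event]
# Censored: survival at the censoring time (survived *at least* this long)
out[!event] = surv_at_time[!event]

stopifnot(!any(out == -99))

# Assumed aggregation: mean negative log-likelihood
rcll = mean(-log(out))
```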
2 changes: 1 addition & 1 deletion R/MeasureSurvSchmid.R
@@ -103,7 +103,7 @@ MeasureSurvSchmid = R6::R6Class("MeasureSurvSchmid",
}

score = weighted_survival_score("schmid", truth = prediction$truth,
distribution = prediction$distr, times = ps$times, t_max = ps$t_max,
distribution = prediction$data$distr, times = ps$times, t_max = ps$t_max,
p_max = ps$p_max, proper = ps$proper, train = train, eps = ps$eps)

if (ps$se) {
2 changes: 1 addition & 1 deletion R/PipeOpSurvAvg.R
@@ -101,7 +101,7 @@ delayedAssign(
distr = map(inputs, "distr")

ok = mlr3misc::map_lgl(distr, function(.x) {
checkmate::test_class(.x, "Matdist")
checkmate::test_class(.x, "Matdist") | checkmate::test_class(.x, "Arrdist")
})

if (all(ok)) {
2 changes: 1 addition & 1 deletion R/PipeOpTaskSurvRegr.R
@@ -228,7 +228,7 @@ PipeOpTaskSurvRegr = R6Class("PipeOpTaskSurvRegr",
}

est = est$train(task)$predict(task)$distr
if (inherits(est, "Matdist")) {
if (inherits(est, c("Matdist", "Arrdist"))) {
weights = diag(est$survival(task$truth()[, 1]))
} else {
weights = as.numeric(est$survival(data = matrix(task$truth()[, 1], nrow = 1)))
13 changes: 10 additions & 3 deletions R/PredictionDataSurv.R
@@ -14,8 +14,10 @@ check_prediction_data.PredictionDataSurv = function(pdata, ...) { # nolint
assert_numeric(pdata$lp, len = n, any.missing = FALSE, null.ok = TRUE)
if (inherits(pdata$distr, "VectorDistribution")) {
assert(nrow(pdata$distr$modelTable) == n)
} else if (inherits(pdata$distr, "Matdist")) {
} else if (inherits(pdata$distr, c("Matdist", "Arrdist"))) {
assert(nrow(gprm(pdata$distr, "pdf")) == n)
} else if (class(pdata$distr)[1] == "array") { # from Arrdist
assert_array(pdata$distr, d = 3, any.missing = FALSE, null.ok = TRUE)
} else {
assert_matrix(pdata$distr, nrows = n, any.missing = FALSE, null.ok = TRUE)
}
@@ -70,7 +72,7 @@ c.PredictionDataSurv = function(..., keep_duplicates = TRUE) {
}

if ("distr" %in% predict_types) {
if (inherits(dots[[1]], c("Matdist", "VectorDistribution"))) {
if (inherits(dots[[1]], c("Matdist", "VectorDistribution", "Arrdist"))) {
result$distr = do.call(c, map(dots, "distr"))
} else {
result$distr = tryCatch(
@@ -107,7 +109,12 @@ filter_prediction_data.PredictionDataSurv = function(pdata, row_ids, ...) {
}

if (!is.null(pdata$distr)) {
pdata$distr = pdata$distr[keep, , drop = FALSE]
if (inherits(pdata$distr, "matrix")) {
pdata$distr = pdata$distr[keep, , drop = FALSE]
} else { # array
pdata$distr = pdata$distr[keep, , , drop = FALSE]
}

}

pdata
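Reviewer note on the filtering change: a small base-R sketch of why the 3d case needs an extra empty index and why `drop = FALSE` matters. The helper `filter_surv()` is hypothetical and dispatches on the number of dimensions instead of `inherits(x, "matrix")`; the array values are placeholders.

```r
# 3 observations x 2 time points x 4 sampled curves (placeholder values)
arr = array(seq(1, 0, length.out = 24), dim = c(3, 2, 4),
  dimnames = list(NULL, c("0", "10"), NULL))
mat = arr[, , 1]                 # an ordinary 3 x 2 survival matrix
keep = c(TRUE, FALSE, FALSE)     # keep only the first observation

# Hypothetical helper: filter rows of a survival matrix or 3d array
filter_surv = function(x, keep) {
  if (length(dim(x)) == 3) {
    x[keep, , , drop = FALSE]
  } else {
    x[keep, , drop = FALSE]
  }
}

dim_mat = dim(filter_surv(mat, keep))   # stays 2-dimensional
dim_arr = dim(filter_surv(arr, keep))   # stays 3-dimensional

# Without drop = FALSE a single kept row collapses to a plain vector
collapsed = mat[keep, ]
```

Keeping the dimensions intact means a filtered prediction with a single row can still be treated as a survival matrix or array downstream.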
53 changes: 39 additions & 14 deletions R/PredictionSurv.R
@@ -11,14 +11,36 @@
#' library(mlr3)
#' task = tsk("rats")
#' learner = lrn("surv.kaplan")
#' p = learner$train(task, row_ids = 1:20)$predict(task, row_ids = 21:30)
#' p = learner$train(task, row_ids = 1:26)$predict(task, row_ids = 27:30)
#' head(as.data.table(p))
#' # survival probabilities of the 4 test rats at two time points
#' p$distr$survival(c(20, 100))
PredictionSurv = R6Class("PredictionSurv",
inherit = Prediction,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @details
#' Upon **initialization**, if the `distr` input is a [Distribution][distr6::Distribution],
#' we try to coerce it either to a survival matrix or a survival array and store it
#' in the `$data$distr` slot for internal use.
#'
#' If the stored `$data$distr` is a [Distribution][distr6::Distribution] object,
#' the active field `$distr` (**external user API**) returns it without modification.
#' Otherwise, if `$data$distr` is a survival matrix or array, `$distr`
#' constructs a distribution out of the `$data$distr` object, which will be a
#' [Matdist][distr6::Matdist] or [Arrdist][distr6::Arrdist] respectively.
#'
#' Note that if a survival 3d array is stored in `$data$distr`, the `$distr`
#' field returns an [Arrdist][distr6::Arrdist] initialized with `which.curve = 0.5`
#' by default (i.e. the median curve). This means that measures that require
#' a `distr` prediction like [MeasureSurvGraf], [MeasureSurvRCLL], etc.
#' will use the median survival probabilities.
#' Note that it is possible to manually change `which.curve` after construction
#' of the predicted distribution but we advise against this as it may lead to
#' inconsistent results.
#'
#' @param task ([TaskSurv])\cr
#' Task, used to extract defaults for `row_ids` and `truth`.
#'
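Reviewer note on the "median curve" wording in the `@details` above: a base-R sketch of what reducing a 3d survival array to its median curve could look like. The function `median_surv_mat()` is hypothetical and is not the package's internal helper; it assumes the array is laid out as observations x time points x sampled curves, and the values are random placeholders rather than valid monotone survival curves.

```r
set.seed(1)
# 2 observations x 3 time points x 5 sampled curves (placeholder values)
arr = array(runif(30), dim = c(2, 3, 5),
  dimnames = list(NULL, c("0", "5", "10"), NULL))

# Hypothetical reduction: per observation and time point, take the
# requested quantile across the third dimension (0.5 = median)
median_surv_mat = function(arr, prob = 0.5) {
  apply(arr, c(1, 2), stats::quantile, probs = prob, names = FALSE)
}

surv_mat = median_surv_mat(arr)
dim(surv_mat)        # observations x time points
colnames(surv_mat)   # time points are carried over from the array
```

The result is an ordinary survival matrix, which is the shape the distr-based measures are described as consuming when the median curve is used.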
@@ -33,10 +55,10 @@ PredictionSurv = R6Class("PredictionSurv",
#' observation in the test set. For a pair of continuous ranks, a higher rank indicates that
#' the observation is more likely to experience the event.
#'
#' @param distr (`matrix()|[distr6::Matdist]|[distr6::VectorDistribution]`)\cr
#' Either a matrix of predicted survival probabilities or a [distr6::VectorDistribution]
#' or a [distr6::Matdist].
#' If a matrix then column names must be given and correspond to survival times.
#' @param distr (`matrix()|[distr6::Arrdist]|[distr6::Matdist]|[distr6::VectorDistribution]`)\cr
#' Either a matrix of predicted survival probabilities, a [distr6::VectorDistribution],
#' a [distr6::Matdist] or an [distr6::Arrdist].
#' If a matrix/array then column names must be given and correspond to survival times.
#' Rows of matrix correspond to individual predictions. It is advised that the
#' first column should be time `0` with all entries `1` and the last
#' with all entries `0`. If a `VectorDistribution` then each distribution in the vector
@@ -57,7 +79,7 @@ PredictionSurv = R6Class("PredictionSurv",
distr = NULL, lp = NULL, response = NULL, check = TRUE) {

if (inherits(distr, "Distribution")) {
# coerce to matrix if possible
# coerce to matrix/array if possible
distr = private$.simplify_distr(distr)
}

@@ -90,14 +112,14 @@ PredictionSurv = R6Class("PredictionSurv",
self$data$crank %??% rep(NA_real_, length(self$data$row_ids))
},

#' @field distr ([distr6::Matdist]|[distr6::VectorDistribution])\cr
#' Convert the stored survival matrix to a survival distribution.
#' @field distr ([distr6::Matdist]|[distr6::Arrdist]|[distr6::VectorDistribution])\cr
#' Convert the stored survival array or matrix to a survival distribution.
distr = function() {
if (inherits(self$data$distr, "Distribution")) {
return(self$data$distr)
}

private$.distrify_survmatrix(self$data$distr %??% NA_real_)
private$.distrify_survarray(self$data$distr %??% NA_real_)
},

#' @field lp (`numeric()`)\cr
@@ -117,8 +139,8 @@ PredictionSurv = R6Class("PredictionSurv",
.censtype = NULL,
.distr = function() self$data$distr %??% NA_real_,
.simplify_distr = function(x) {
if (inherits(x, "Matdist")) {
1 - gprm(x, "cdf")
if (inherits(x, c("Matdist", "Arrdist"))) {
1 - gprm(x, "cdf") # matrix or 3d array
} else {
if (!inherits(x, "VectorDistribution")) {
stop("'x' is not a 'VectorDistribution'")
@@ -148,9 +170,12 @@ PredictionSurv = R6Class("PredictionSurv",
surv
}
},
.distrify_survmatrix = function(x) {
distr6::as.Distribution(1 - x, fun = "cdf",
decorators = c("CoreStatistics", "ExoticStatistics"))
.distrify_survarray = function(x) {
if (inherits(x, "array")) { # can be matrix as well
# create Matdist or Arrdist (default => median curve)
distr6::as.Distribution(1 - x, fun = "cdf",
decorators = c("CoreStatistics", "ExoticStatistics"))
}
}
)
)
16 changes: 12 additions & 4 deletions R/integrated_scores.R
@@ -26,18 +26,26 @@ weighted_survival_score = function(loss, truth, distribution, times, t_max, p_ma
unique_times = .c_get_unique_times(truth[, "time"], times)
}

# get the cdf matrix (rows => times, cols => obs)
if (inherits(distribution, "Distribution")) {
cdf = as.matrix(distribution$cdf(unique_times))
} else {
mtc = findInterval(unique_times, as.numeric(colnames(distribution)))
cdf = 1 - t(distribution[, mtc])
}
else if (inherits(distribution, "array")) {
if (length(dim(distribution)) == 3) {
# survival 3d array, extract median
surv_mat = .ext_surv_mat(arr = distribution, which.curve = 0.5)
} else { # survival 2d array
surv_mat = distribution
}
mtc = findInterval(unique_times, as.numeric(colnames(surv_mat)))
cdf = 1 - t(surv_mat[, mtc])
if (any(mtc == 0)) {
cdf = rbind(matrix(0, sum(mtc == 0), ncol(cdf)), cdf)
}
rownames(cdf) = unique_times
}
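Reviewer note: the column matching in this branch can be traced with a toy survival matrix. The sketch below reuses the same base-R calls as the hunk (`findInterval()`, `t()`, `rbind()`); the matrix, its time points, and `unique_times` are invented for illustration, and the first column is deliberately not time 0 so that the `mtc == 0` padding is exercised.

```r
# Toy survival matrix: rows = observations, column names = time points
surv_mat = matrix(c(1, 0.8, 0.5,
                    1, 0.6, 0.2),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("5", "10", "20")))
unique_times = c(2, 5, 12, 25)

# For each requested time, index of the last column time <= that time
# (0 means the requested time precedes the first column)
mtc = findInterval(unique_times, as.numeric(colnames(surv_mat)))

# Zero indices are silently dropped by `[`, so only matched times remain
cdf = 1 - t(surv_mat[, mtc])   # rows = times, cols = observations

# Pad one all-zero row per unmatched early time: cdf = 0 before the first time
if (any(mtc == 0)) {
  cdf = rbind(matrix(0, sum(mtc == 0), ncol(cdf)), cdf)
}
rownames(cdf) = unique_times
```

Because `unique_times` is sorted, the unmatched times are always the earliest ones, which is why the zero rows can simply be stacked on top.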

true_times <- truth[, "time"]
true_times = truth[, "time"]

assert_numeric(true_times, any.missing = FALSE)
assert_numeric(unique_times, any.missing = FALSE)
8 changes: 4 additions & 4 deletions R/pecs.R
@@ -113,7 +113,7 @@ pecs.list = function(x, measure = c("graf", "logloss"), times, n, eps = NULL, ta
scores = lapply(p, function(y) {
integrated_score(score = weighted_survival_score("intslogloss",
truth = task$truth(),
distribution = y$distr,
distribution = y$data$distr,
times = times,
eps = eps, train = train,
proper = proper),
@@ -123,7 +123,7 @@ pecs.list = function(x, measure = c("graf", "logloss"), times, n, eps = NULL, ta
scores = lapply(p, function(y) {
integrated_score(score = weighted_survival_score("graf",
truth = task$truth(),
distribution = y$distr,
distribution = y$data$distr,
times = times, train = train, eps = eps,
proper = proper),
integrated = FALSE)
@@ -169,15 +169,15 @@ pecs.PredictionSurv = function(x, measure = c("graf", "logloss"), times, n, eps
scores = data.frame(logloss = integrated_score(
score = weighted_survival_score("intslogloss",
truth = x$truth,
distribution = x$distr,
distribution = x$data$distr,
times = times,
eps = eps, train = train, proper = proper),
integrated = FALSE))
} else {
scores = data.frame(graf = integrated_score(
score = weighted_survival_score("graf",
truth = x$truth,
distribution = x$distr,
distribution = x$data$distr,
times = times, train = train, eps = eps,
proper = proper),
integrated = FALSE))
4 changes: 2 additions & 2 deletions R/surv_measures.R
@@ -1,7 +1,7 @@
surv_logloss = function(truth, distribution, eps = 1e-15, IPCW = TRUE, train = NULL, ...) {

# calculate pdf at true death time
if (inherits(distribution, "Matdist")) {
if (inherits(distribution, c("Matdist", "Arrdist"))) {
pred = diag(distribution$pdf(truth[, 1]))
} else {
pred = as.numeric(distribution$pdf(data = matrix(truth[, 1], nrow = 1)))
@@ -36,7 +36,7 @@ surv_logloss = function(truth, distribution, eps = 1e-15, IPCW = TRUE, train = N
truth = truth[uncensored, 1]
distribution = distribution[uncensored]

if (inherits(distribution, "Matdist")) {
if (inherits(distribution, c("Matdist", "Arrdist"))) {
cens = diag(distribution$survival(truth))
} else {
cens = as.numeric(distribution$survival(data = matrix(truth, nrow = 1)))