Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support arrdist #331

Merged
merged 38 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
d5df42d
change input to weighted_survival_score() for Graf score + tests
bblodfon Sep 8, 2023
72b33ad
add support for distr6::Arrdist prediction type
bblodfon Sep 8, 2023
10e48aa
Merge pull request #330 from bblodfon/main
RaphaelS1 Sep 9, 2023
cb549c2
extend surv_return() to handle survival arrays + tests
bblodfon Sep 10, 2023
ed430ee
revert graf change
RaphaelS1 Sep 10, 2023
ff8cdbd
pass names when input is a vector and 'times' is not given
bblodfon Sep 11, 2023
6b74197
change the default curve to median
bblodfon Sep 11, 2023
865bf03
update test, add edge cases
bblodfon Sep 11, 2023
27e18d2
speeding up RCLL by refactoring a bit
bblodfon Sep 11, 2023
a65d691
fix bug: filtering PredictionSurv obj works with 3d survival arrays
bblodfon Sep 11, 2023
d1d8cfa
Merge branch 'main' into support_arrdist
RaphaelS1 Sep 15, 2023
ddcf3c1
Update DESCRIPTION
RaphaelS1 Sep 15, 2023
0f1a5dd
Update DESCRIPTION
RaphaelS1 Sep 15, 2023
43f244a
better example
bblodfon Sep 18, 2023
66282af
revert to original distrification
bblodfon Sep 18, 2023
6df6afe
better doc
bblodfon Sep 18, 2023
edd33f8
fix bug when input is a 3d survival array
bblodfon Sep 18, 2023
87d69b2
update example
bblodfon Sep 19, 2023
7b70fd8
fix bug in weighted_survival_score()
bblodfon Sep 19, 2023
a8306e4
code optimization
bblodfon Sep 25, 2023
86dd71c
code optimization
bblodfon Sep 25, 2023
293c4c1
test distr measures with 3d survival array
bblodfon Sep 25, 2023
44e1551
fix R CMD check warnings
bblodfon Sep 25, 2023
33bdc13
revert changes that tested graf score results numerically
bblodfon Sep 25, 2023
af3810c
bug fix + support combining survival array distrs
bblodfon Oct 4, 2023
d22a8ef
refactor helper function to 3d-ify a survival matrix
bblodfon Oct 4, 2023
3f59e53
add tests to combine 3d survival arrays
bblodfon Oct 4, 2023
dfc97f0
add 'which.curve' parameter to integrated scores and pecs + tests
bblodfon Oct 4, 2023
7237f70
move argument 'which.curve' into last position
bblodfon Oct 6, 2023
1fabbd7
Revert "add 'which.curve' parameter to integrated scores and pecs + t…
bblodfon Oct 6, 2023
085e87a
sapply => lapply
bblodfon Oct 6, 2023
c5bec7b
update distr6 dependency
RaphaelS1 Oct 8, 2023
0ae56d0
fix 2 small bugs
bblodfon Oct 13, 2023
707482e
update tests (combining different prediction types)
bblodfon Oct 13, 2023
6b01f57
combining survival matrices with arrays is now supported
bblodfon Oct 15, 2023
36e6b90
pump distr6
bblodfon Oct 15, 2023
762d807
correct distr6 PR
bblodfon Oct 15, 2023
a4b9d88
Update DESCRIPTION
RaphaelS1 Oct 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.5.2
Version: 0.5.3
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -43,7 +43,7 @@ Depends:
Imports:
checkmate,
data.table,
distr6 (>= 1.6.11),
distr6 (>= 1.8.0),
ggplot2,
mlr3misc (>= 0.7.0),
mlr3viz,
Expand All @@ -62,7 +62,7 @@ Suggests:
param6 (>= 0.2.4),
pracma,
rpart,
set6 (>= 0.1.7),
set6 (>= 0.2.6),
simsurv,
survAUC,
testthat,
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureRegrLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ MeasureRegrLogloss = R6::R6Class("MeasureRegrLogloss",
distr = prediction$distr
truth = prediction$truth

if (inherits(distr, "Matdist")) {
if (inherits(distr, c("Matdist", "Arrdist"))) {
pdf = diag(distr$pdf(truth))
} else {
pdf = as.numeric(distr$pdf(data = matrix(truth, nrow = 1)))
Expand Down
2 changes: 1 addition & 1 deletion R/PipeOpSurvAvg.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ delayedAssign(
distr = map(inputs, "distr")

ok = mlr3misc::map_lgl(distr, function(.x) {
checkmate::test_class(.x, "Matdist")
checkmate::test_class(.x, "Matdist") | checkmate::test_class(.x, "Arrdist")
})

if (all(ok)) {
Expand Down
2 changes: 1 addition & 1 deletion R/PipeOpTaskSurvRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ PipeOpTaskSurvRegr = R6Class("PipeOpTaskSurvRegr",
}

est = est$train(task)$predict(task)$distr
if (inherits(est, "Matdist")) {
if (inherits(est, c("Matdist", "Arrdist"))) {
weights = diag(est$survival(task$truth()[, 1]))
} else {
weights = as.numeric(est$survival(data = matrix(task$truth()[, 1], nrow = 1)))
Expand Down
6 changes: 4 additions & 2 deletions R/PredictionDataSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ check_prediction_data.PredictionDataSurv = function(pdata, ...) { # nolint
assert_numeric(pdata$lp, len = n, any.missing = FALSE, null.ok = TRUE)
if (inherits(pdata$distr, "VectorDistribution")) {
assert(nrow(pdata$distr$modelTable) == n)
} else if (inherits(pdata$distr, "Matdist")) {
} else if (inherits(pdata$distr, c("Matdist", "Arrdist"))) {
assert(nrow(gprm(pdata$distr, "pdf")) == n)
} else if (class(pdata$distr)[1] == "array") { # from Arrdist
assert_array(pdata$distr, d = 3, any.missing = FALSE, null.ok = TRUE)
} else {
assert_matrix(pdata$distr, nrows = n, any.missing = FALSE, null.ok = TRUE)
}
Expand Down Expand Up @@ -70,7 +72,7 @@ c.PredictionDataSurv = function(..., keep_duplicates = TRUE) {
}

if ("distr" %in% predict_types) {
if (inherits(dots[[1]], c("Matdist", "VectorDistribution"))) {
if (inherits(dots[[1]], c("Matdist", "VectorDistribution", "Arrdist"))) {
result$distr = do.call(c, map(dots, "distr"))
} else {
result$distr = tryCatch(
Expand Down
41 changes: 29 additions & 12 deletions R/PredictionSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,22 @@
#' learner = lrn("surv.kaplan")
#' p = learner$train(task, row_ids = 1:20)$predict(task, row_ids = 21:30)
#' head(as.data.table(p))
#' class(p$data$distr) # survival matrix stored
#' p$distr # Matdist
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
PredictionSurv = R6Class("PredictionSurv",
inherit = Prediction,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @details
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
#' Upon initialization, the `distr` input will be coerced to a survival matrix
#' or array (accessible via `$data$distr`) if it's a [Distribution][distr6::Distribution]
#' object. The active field `$distr` always returns a distribution
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
#' ([Matdist][distr6::Matdist] or [Arrdist][distr6::Arrdist]) depending on
#' the class of the stored `$data$distr`. In the case of an [Arrdist][distr6::Arrdist],
#' the distribution is by default initialized using `which.curve = 'mean'`.
#'
#' @param task ([TaskSurv])\cr
#' Task, used to extract defaults for `row_ids` and `truth`.
#'
Expand All @@ -33,10 +43,10 @@ PredictionSurv = R6Class("PredictionSurv",
#' observation in the test set. For a pair of continuous ranks, a higher rank indicates that
#' the observation is more likely to experience the event.
#'
#' @param distr (`matrix()|[distr6::Matdist]|[distr6::VectorDistribution]`)\cr
#' Either a matrix of predicted survival probabilities or a [distr6::VectorDistribution]
#' or a [distr6::Matdist].
#' If a matrix then column names must be given and correspond to survival times.
#' @param distr (`matrix()|[distr6::Arrdist]|[distr6::Matdist]|[distr6::VectorDistribution]`)\cr
#' Either a matrix of predicted survival probabilities, a [distr6::VectorDistribution],
#' a [distr6::Matdist] or an [distr6::Arrdist].
#' If a matrix/array then column names must be given and correspond to survival times.
#' Rows of matrix correspond to individual predictions. It is advised that the
#' first column should be time `0` with all entries `1` and the last
#' with all entries `0`. If a `VectorDistribution` then each distribution in the vector
Expand All @@ -57,7 +67,7 @@ PredictionSurv = R6Class("PredictionSurv",
distr = NULL, lp = NULL, response = NULL, check = TRUE) {

if (inherits(distr, "Distribution")) {
# coerce to matrix if possible
# coerce to matrix/array if possible
distr = private$.simplify_distr(distr)
}

Expand Down Expand Up @@ -90,14 +100,14 @@ PredictionSurv = R6Class("PredictionSurv",
self$data$crank %??% rep(NA_real_, length(self$data$row_ids))
},

#' @field distr ([distr6::Matdist]|[distr6::VectorDistribution])\cr
#' Convert the stored survival matrix to a survival distribution.
#' @field distr ([distr6::Matdist]|[distr6::Arrdist]|[distr6::VectorDistribution])\cr
#' Convert the stored survival array or matrix to a survival distribution.
distr = function() {
if (inherits(self$data$distr, "Distribution")) {
return(self$data$distr)
}

private$.distrify_survmatrix(self$data$distr %??% NA_real_)
private$.distrify_survarray(self$data$distr %??% NA_real_)
},

#' @field lp (`numeric()`)\cr
Expand All @@ -117,7 +127,7 @@ PredictionSurv = R6Class("PredictionSurv",
.censtype = NULL,
.distr = function() self$data$distr %??% NA_real_,
.simplify_distr = function(x) {
if (inherits(x, "Matdist")) {
if (inherits(x, c("Matdist", "Arrdist"))) {
1 - gprm(x, "cdf")
} else {
if (!inherits(x, "VectorDistribution")) {
Expand Down Expand Up @@ -148,9 +158,16 @@ PredictionSurv = R6Class("PredictionSurv",
surv
}
},
.distrify_survmatrix = function(x) {
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
distr6::as.Distribution(1 - x, fun = "cdf",
decorators = c("CoreStatistics", "ExoticStatistics"))
.distrify_survarray = function(x) {
if (inherits(x, "matrix")) {
# create Matdist
distr6::as.Distribution(1 - x, fun = "cdf",
decorators = c("CoreStatistics", "ExoticStatistics"))
} else {
# create Arrdist
distr6::Arrdist$new(cdf = 1 - x, which.curve = "mean",
decorators = c("CoreStatistics", "ExoticStatistics"))
}
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
}
)
)
Expand Down
4 changes: 2 additions & 2 deletions R/surv_measures.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
surv_logloss = function(truth, distribution, eps = 1e-15, IPCW = TRUE, train = NULL, ...) {

# calculate pdf at true death time
if (inherits(distribution, "Matdist")) {
if (inherits(distribution, c("Matdist", "Arrdist"))) {
pred = diag(distribution$pdf(truth[, 1]))
} else {
pred = as.numeric(distribution$pdf(data = matrix(truth[, 1], nrow = 1)))
Expand Down Expand Up @@ -36,7 +36,7 @@ surv_logloss = function(truth, distribution, eps = 1e-15, IPCW = TRUE, train = N
truth = truth[uncensored, 1]
distribution = distribution[uncensored]

if (inherits(distribution, "Matdist")) {
if (inherits(distribution, c("Matdist", "Arrdist"))) {
cens = diag(distribution$survival(truth))
} else {
cens = as.numeric(distribution$survival(data = matrix(truth, nrow = 1)))
Expand Down
87 changes: 75 additions & 12 deletions R/surv_return.R
Original file line number Diff line number Diff line change
@@ -1,36 +1,59 @@
#' @title Get Survival Predict Types
#' @description Internal helper function to easily return the correct survival predict types and to
#' automatically coerce a predicted survival probability matrix to a [distr6::Matdist].
#'
#' @description Internal helper function to easily return the correct survival predict types.
#'
#' @param times (`numeric()`) \cr Vector of survival times.
#' @param surv (`matrix()`)\cr Matrix of predicted survival probabilities, rows are observations,
#' columns are times. Number of columns should be equal to length of `times`.
#' @param surv (`matrix()|array()`)\cr Matrix or array of predicted survival
#' probabilities, rows (1st dimension) are observations, columns (2nd dimension)
#' are times and in the case of an array there should be one more dimension.
#' Number of columns should be equal to length of `times`. In case a `numeric()`
#' vector is provided, it is converted to a single row (one observation) matrix.
#' @param which.curve Which curve (3rd dimension) should the `crank` be
#' calculated for, in case `surv` is an `array`? If between (0,1) it is taken as
#' the quantile of the curves, otherwise if 1 or greater it is taken as the
#' curve index. It can also be 'mean' and the survival probabilities are averaged
#' across the 3rd dimension. Default value (`NULL`) is the **0.5 quantile** which
#' is the median across the 3rd dimension of the survival array.
#' @param crank (`numeric()`)\cr Relative risk/continuous ranking. Higher value is associated
#' with higher risk. If `NULL` then either set as `-response` if available or
#' `lp` if available (this assumes that the `lp` prediction comes from a PH type
#' model - in case of an AFT model the user should provide `-lp`).
#' In case neither `response` or `lp` are provided, then `crank` is calculated
#' as the sum of the cumulative hazard function (expected mortality) derived from
#' the predicted survival function (`surv`).
#' the predicted survival function (`surv`). In case `surv` is a 3d array, we use
#' the `which.curve` parameter to decide which survival matrix (index in the 3rd
#' dimension) will be chosen for the calculation of `crank`.
#' @param lp (`numeric()`)\cr Predicted linear predictor, used to impute `crank` if `NULL`.
#' @param response (`numeric()`)\cr Predicted survival time, passed through function without
#' modification.
#'
#' @details
#' Uses [survivalmodels::surv_to_risk] to reduce survival matrices to relative
#' risks / rankings if `crank` is NULL.
#'
#' @references
#' Sonabend, R., Bender, A., & Vollmer, S. (2022). Avoiding C-hacking when
#' evaluating survival distribution predictions with discrimination measures.
#' Bioinformatics. https://doi.org/10.1093/BIOINFORMATICS/BTAC451
#' @export
.surv_return = function(times = NULL, surv = NULL, crank = NULL, lp = NULL, response = NULL) {
.surv_return = function(times = NULL, surv = NULL, which.curve = NULL,
crank = NULL, lp = NULL, response = NULL) {

if (!is.null(surv)) {
if (class(surv)[1] == "numeric") {
surv = matrix(surv, nrow = 1)
# in case of a vector (one observation) convert to matrix
surv = matrix(surv, nrow = 1, dimnames = list(NULL, names(surv)))
}
times <- times %||% colnames(surv)
assert(length(times) == ncol(surv))
colnames(surv) <- times
if (class(surv)[1] == "array") {
if (length(dim(surv)) != 3) {
stop("3D survival arrays supported only")
}
}
times = times %||% colnames(surv)
if (length(times) != ncol(surv)) {
stop("'times' must have the same length as the 2nd dimension (columns of 'surv')")
}
colnames(surv) = times
}

if (is.null(crank)) {
Expand All @@ -42,14 +65,54 @@
# assumes PH-type lp where high value = high risk
crank = lp
} else if (!is.null(surv)) {
crank = survivalmodels::surv_to_risk(surv)
if (inherits(surv, "matrix")) {
crank = survivalmodels::surv_to_risk(surv)
} else { # array
surv_mat = .ext_surv_mat(surv, which.curve)
crank = survivalmodels::surv_to_risk(surv_mat)
}
}
}

list(
distr = surv,
distr = surv, # matrix or array
crank = crank,
lp = lp,
response = response
)
}

# Helper function to extract a survival matrix from a 3D survival array.
#
# Arguments:
#   arr         3d array; the first two dimensions are kept and the 3rd
#               dimension is collapsed or indexed.
#   which.curve how the 3rd dimension is reduced:
#                 - NULL: the 0.5 quantile (median) across the 3rd dimension
#                 - "mean": the mean across the 3rd dimension
#                 - a number in (0, 1): that quantile across the 3rd dimension
#                 - a number >= 1: the index of the 3rd dimension to extract
#
# Returns a matrix with `nrow(arr)` rows and `ncol(arr)` columns that carries
# the dimnames of the first two dimensions of `arr`.
#
# Stops with an error if `which.curve` is not a single "mean" or positive
# number (an `NA` value is rejected with the same message), or if it is larger
# than the length of the 3rd dimension.
.ext_surv_mat = function(arr, which.curve) {
  # re-wrap a reduced result as an (nrow x ncol) matrix with the original
  # row/column names
  as_surv_mat = function(values) {
    array(values, c(nrow(arr), ncol(arr)), dimnames(arr)[c(1, 2)])
  }
  # per-cell quantile across the 3rd dimension
  quantile_curve = function(prob) {
    as_surv_mat(apply(arr, c(1, 2), quantile, prob))
  }

  # if NULL return the 'median' curve (default)
  if (is.null(which.curve)) {
    return(quantile_curve(0.5))
  }

  # which.curve must be length 1 and either 'mean' or >0; isTRUE() makes an
  # NA comparison count as invalid so that the informative error below is
  # raised instead of a failure inside `if`
  is_mean = is.character(which.curve) && isTRUE(which.curve == "mean")
  is_positive = is.numeric(which.curve) && isTRUE(which.curve > 0)
  ok = length(which.curve) == 1 && (is_mean || is_positive)
  if (!ok) {
    stop("'which.curve' has to be a numeric between (0,1) or the index of the\n3rd dimension or 'mean'")
  }

  if (is.numeric(which.curve) && which.curve > dim(arr)[3L]) {
    stop(sprintf("Length is %s on third dimension but curve '%s' requested,\nchange 'which.curve' parameter.", dim(arr)[3L], which.curve))
  }

  if (is_mean) {
    # mean
    apply(arr, c(1, 2), mean)
  } else if (which.curve < 1) {
    # curve chosen based on quantile
    quantile_curve(which.curve)
  } else {
    # curve chosen based on index
    as_surv_mat(arr[, , which.curve])
  }
}
2 changes: 1 addition & 1 deletion inst/testthat/helper_expectations.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ expect_prediction_surv = function(p) {
checkmate::expect_data_table(data.table::as.data.table(p), nrows = length(p$row_ids))
checkmate::expect_atomic_vector(p$missing)
if ("distr" %in% p$predict_types) {
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist"))
expect_true(class(p$distr)[[1]] %in% c("VectorDistribution", "Matdist", "Arrdist"))
}
expect_true(inherits(p, "PredictionSurv"))
}
23 changes: 17 additions & 6 deletions man/PredictionSurv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading