Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support arrdist #331

Merged
merged 38 commits into from
Oct 16, 2023
Merged
Changes from 1 commit
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
d5df42d
change input to weighted_survival_score() for Graf score + tests
bblodfon Sep 8, 2023
72b33ad
add support for distr6::Arrdist prediction type
bblodfon Sep 8, 2023
10e48aa
Merge pull request #330 from bblodfon/main
RaphaelS1 Sep 9, 2023
cb549c2
extend surv_return() to handle survival arrays + tests
bblodfon Sep 10, 2023
ed430ee
revert graf change
RaphaelS1 Sep 10, 2023
ff8cdbd
pass names when input is a vector and 'times' is not given
bblodfon Sep 11, 2023
6b74197
change the default curve to median
bblodfon Sep 11, 2023
865bf03
update test, add edge cases
bblodfon Sep 11, 2023
27e18d2
speeding up RCLL by refactoring a bit
bblodfon Sep 11, 2023
a65d691
fix bug: filtering PredictionSurv obj works with 3d survival arrays
bblodfon Sep 11, 2023
d1d8cfa
Merge branch 'main' into support_arrdist
RaphaelS1 Sep 15, 2023
ddcf3c1
Update DESCRIPTION
RaphaelS1 Sep 15, 2023
0f1a5dd
Update DESCRIPTION
RaphaelS1 Sep 15, 2023
43f244a
better example
bblodfon Sep 18, 2023
66282af
revert to original distrification
bblodfon Sep 18, 2023
6df6afe
better doc
bblodfon Sep 18, 2023
edd33f8
fix bug when input is a 3d survival array
bblodfon Sep 18, 2023
87d69b2
update example
bblodfon Sep 19, 2023
7b70fd8
fix bug in weighted_survival_score()
bblodfon Sep 19, 2023
a8306e4
code optimization
bblodfon Sep 25, 2023
86dd71c
code optimization
bblodfon Sep 25, 2023
293c4c1
test distr measures with 3d survival array
bblodfon Sep 25, 2023
44e1551
fix R CMD check warnings
bblodfon Sep 25, 2023
33bdc13
revert changes that tested graf score results numerically
bblodfon Sep 25, 2023
af3810c
bug fix + support combining survival array distrs
bblodfon Oct 4, 2023
d22a8ef
refactor helper function to 3d-ify a survival matrix
bblodfon Oct 4, 2023
3f59e53
add tests to combine 3d survival arrays
bblodfon Oct 4, 2023
dfc97f0
add 'which.curve' parameter to integrated scores and pecs + tests
bblodfon Oct 4, 2023
7237f70
move argument 'which.curve' into last position
bblodfon Oct 6, 2023
1fabbd7
Revert "add 'which.curve' parameter to integrated scores and pecs + t…
bblodfon Oct 6, 2023
085e87a
sapply => lapply
bblodfon Oct 6, 2023
c5bec7b
update distr6 dependency
RaphaelS1 Oct 8, 2023
0ae56d0
fix 2 small bugs
bblodfon Oct 13, 2023
707482e
update tests (combining different prediction types)
bblodfon Oct 13, 2023
6b01f57
combining survival matrices with arrays is now supported
bblodfon Oct 15, 2023
36e6b90
pump distr6
bblodfon Oct 15, 2023
762d807
correct distr6 PR
bblodfon Oct 15, 2023
a4b9d88
Update DESCRIPTION
RaphaelS1 Oct 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions R/integrated_scores.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@ weighted_survival_score = function(loss, truth, distribution, times, t_max, p_ma
unique_times = .c_get_unique_times(truth[, "time"], times)
}

is_distr_or_array =
(inherits(distribution, "Distribution")) ||
(inherits(distribution, "array") & length(dim(array)) == 3)

if (is_distr_or_array) {
# get the cdf
if (inherits(distribution, "Distribution")) {
cdf = as.matrix(distribution$cdf(unique_times))
} else {
}
else if (inherits(distribution, "array") &
length(dim(distribution)) == 3) {
# 'distribution' is a survival 3d array so create an
# `Arrdist` using the 'median' curve
arrdistr = distr6::as.Distribution(1 - distribution, fun = "cdf",
decorators = c("CoreStatistics", "ExoticStatistics"))
cdf = as.matrix(arrdistr$cdf(unique_times))
} else { # 'distribution' is a survival 2d matrix
bblodfon marked this conversation as resolved.
Show resolved Hide resolved
mtc = findInterval(unique_times, as.numeric(colnames(distribution)))
cdf = 1 - t(distribution[, mtc])
if (any(mtc == 0)) {
Expand All @@ -41,7 +46,7 @@ weighted_survival_score = function(loss, truth, distribution, times, t_max, p_ma
rownames(cdf) = unique_times
}

true_times <- truth[, "time"]
true_times = truth[, "time"]

assert_numeric(true_times, any.missing = FALSE)
assert_numeric(unique_times, any.missing = FALSE)
Expand Down
Loading