Skip to content

Commit

Permalink
update calibration methods to work better with non-standard names (#146)
Browse files Browse the repository at this point in the history
* changes for #145

* expr_name -> as_name
  • Loading branch information
topepo authored Jun 5, 2024
1 parent 29d9fb3 commit d413853
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 9 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

* A new function `bound_prediction()` is available to constrain the values of a numeric prediction (#142).

* Fixed a bug where non-standard names of class probability estimates resulted in an error for some calibration models (#145).

* Fixed a bug in `cal_plot_breaks()` for binary classification when custom probability column names are used (#144).

* Fixed an error in `int_conformal_cv()` when grouped resampling was used (#141).


# probably 1.0.3

* Fixed a bug where the grouping for calibration methods was sensitive to the type of the grouping variables (#127).
Expand Down
2 changes: 1 addition & 1 deletion R/cal-apply-impl.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ apply_adjustment <- function(new_data, object) {
}

if (object$type == "one_vs_all") {
ols <- as.character(object$levels)
ols <- purrr::map_chr(object$levels, rlang::as_name)
rs <- rowSums(new_data[, ols])
for (i in seq_along(ols)) {
new_data[, ols[i]] <- new_data[, ols[i]] / rs
Expand Down
6 changes: 3 additions & 3 deletions R/cal-apply-multi.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ apply_multi_predict <- function(object, .data) {
preds <- object$estimates[[1]]$estimate %>%
predict(newdata = .data, type = prob_type)

colnames(preds) <- as.character(object$levels)
colnames(preds) <- purrr::map_chr(object$levels, rlang::as_name)
preds <- dplyr::as_tibble(preds)

for (i in seq_along(object$levels)) {
lev <- object$levels[i]
.data[, as.character(lev)] <- preds[, as.character(lev)]
lev <- rlang::as_name(object$levels[[i]])
.data[, lev] <- preds[, lev]
}
.data
}
2 changes: 1 addition & 1 deletion R/cal-apply.R
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ cal_adjust_update <- function(.data,
res[, pred_name] <- NULL
}

col_names <- as.character(object$levels)
col_names <- purrr::map_chr(object$levels, rlang::as_name)
factor_levels <- names(object$levels)

predictions <- res[, col_names] %>%
Expand Down
8 changes: 6 additions & 2 deletions R/cal-estimate-beta.R
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ cal_beta_impl_grp <- function(.data,
estimate = NULL,
levels = NULL,
...) {

list_names <- purrr::map_chr(estimate, rlang::as_name)

.data %>%
split_dplyr_groups() %>%
lapply(
Expand All @@ -165,8 +168,9 @@ cal_beta_impl_grp <- function(.data,
location_params = location_params,
estimate = estimate,
...
) %>%
rlang::set_names(as.character(estimate))
)

names(estimate) <- list_names

list(
filter = x$filter,
Expand Down
7 changes: 5 additions & 2 deletions R/cal-estimate-isotonic.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ cal_isoreg_impl <- function(.data,
}

cal_isoreg_impl_grp <- function(.data, truth, estimate, sampled, ...) {
list_names <- purrr::map_chr(estimate, rlang::as_name)

.data %>%
split_dplyr_groups() %>%
lapply(
Expand All @@ -281,8 +283,9 @@ cal_isoreg_impl_grp <- function(.data, truth, estimate, sampled, ...) {
estimate = estimate,
sampled = sampled,
... = ...
) %>%
rlang::set_names(as.character(estimate))
)

names(iso_model) <- list_names

list(
filter = x$filter,
Expand Down
15 changes: 15 additions & 0 deletions tests/testthat/test-cal-estimate.R
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,18 @@ test_that("Test exceptions", {
cal_estimate_isotonic(segment_logistic, Class, dplyr::starts_with("bad"))
)
})

test_that("non-standard column names", {
  # Regression test for issue 145: calibration should not error when the
  # probability columns and the class levels are not syntactically valid R
  # names (here they all carry a "-1" suffix, so they contain a hyphen).
  seg <- segment_logistic %>%
    # `starts_with()` matches the literal ".pred" prefix; a regex would treat
    # the leading dot as a wildcard.
    dplyr::rename_with(~ paste0(.x, "-1"), dplyr::starts_with(".pred")) %>%
    dplyr::mutate(
      Class = factor(paste0(Class, "-1")),
      .pred_class = ifelse(`.pred_poor-1` >= 0.5, "poor-1", "good-1")
    )

  calib <- cal_estimate_isotonic(seg, Class)
  new_pred <- cal_apply(seg, calib, pred_class = .pred_class)

  # Column names (and their order) must survive estimation and application.
  expect_named(
    new_pred,
    c(".pred_poor-1", ".pred_good-1", "Class", ".pred_class")
  )
})

0 comments on commit d413853

Please sign in to comment.