Skip to content

Commit

Permalink
loopify step_bs()
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Jul 31, 2023
1 parent 49bd773 commit 08951bc
Showing 1 changed file with 12 additions and 21 deletions.
33 changes: 12 additions & 21 deletions R/bs.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,29 +168,20 @@ prep.step_bs <- function(x, training, info = NULL, ...) {

#' @export
bake.step_bs <- function(object, new_data, ...) {
check_new_data(names(object$objects), object, new_data)

## pre-allocate a matrix for the basis functions.
new_cols <- vapply(object$objects, ncol, c(int = 1L))
bs_values <-
matrix(NA, nrow = nrow(new_data), ncol = sum(new_cols))
colnames(bs_values) <- rep("", sum(new_cols))
strt <- 1
for (i in names(object$objects)) {
cols <- (strt):(strt + new_cols[i] - 1)
orig_var <- attr(object$objects[[i]], "var")
bs_values[, cols] <-
bs_predict(object$objects[[i]], new_data[[i]])
new_names <-
paste(orig_var, "bs", names0(new_cols[i], ""), sep = "_")
colnames(bs_values)[cols] <- new_names
strt <- max(cols) + 1
new_data <- remove_original_cols(new_data, object, orig_var)
col_names <- names(object$objects)
check_new_data(col_names, object, new_data)

for (col_name in col_names) {
new_values <- bs_predict(object$objects[[col_name]], new_data[[col_name]])
new_values <- as_tibble(new_values)
new_names <- paste(col_name, "bs", names0(ncol(new_values), ""), sep = "_")

colnames(new_values) <- new_names
new_values <- check_name(new_values, new_data, object, new_names)
new_data <- vec_cbind(new_data, new_values)
}
bs_values <- as_tibble(bs_values)
bs_values <- check_name(bs_values, new_data, object, names(bs_values))

new_data <- vec_cbind(new_data, bs_values)
new_data <- remove_original_cols(new_data, object, col_names)
new_data
}

Expand Down

0 comments on commit 08951bc

Please sign in to comment.