Skip to content

Commit

Permalink
loopify step_ns()
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Jul 31, 2023
1 parent 08951bc commit 92281d4
Showing 1 changed file with 13 additions and 20 deletions.
33 changes: 13 additions & 20 deletions R/ns.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,27 +159,20 @@ prep.step_ns <- function(x, training, info = NULL, ...) {

#' @export
bake.step_ns <- function(object, new_data, ...) {
check_new_data(names(object$objects), object, new_data)
## pre-allocate a matrix for the basis functions.
new_cols <- vapply(object$objects, ncol, c(int = 1L))
ns_values <-
matrix(NA, nrow = nrow(new_data), ncol = sum(new_cols))
colnames(ns_values) <- rep("", sum(new_cols))
strt <- 1
for (i in names(object$objects)) {
cols <- (strt):(strt + new_cols[i] - 1)
orig_var <- attr(object$objects[[i]], "var")
ns_values[, cols] <-
ns_predict(object$objects[[i]], new_data[[i]])
new_names <-
paste(orig_var, "ns", names0(new_cols[i], ""), sep = "_")
colnames(ns_values)[cols] <- new_names
strt <- max(cols) + 1
new_data <- remove_original_cols(new_data, object, orig_var)
col_names <- names(object$objects)
check_new_data(col_names, object, new_data)

for (col_name in col_names) {
new_values <- bs_predict(object$objects[[col_name]], new_data[[col_name]])
new_values <- as_tibble(new_values)
new_names <- paste(col_name, "ns", names0(ncol(new_values), ""), sep = "_")

colnames(new_values) <- new_names
new_values <- check_name(new_values, new_data, object, new_names)
new_data <- vec_cbind(new_data, new_values)
}
ns_values <- as_tibble(ns_values)
ns_values <- check_name(ns_values, new_data, object, names(ns_values))
new_data <- vec_cbind(new_data, ns_values)

new_data <- remove_original_cols(new_data, object, col_names)
new_data
}

Expand Down

0 comments on commit 92281d4

Please sign in to comment.