Skip to content

Commit

Permalink
Merge pull request #1179 from tidymodels/loopify-lag
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt authored Aug 2, 2023
2 parents f5f1aa7 + 393d606 commit b7c6419
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
37 changes: 18 additions & 19 deletions R/lag.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ step_lag_new <-

#' @export
prep.step_lag <- function(x, training, info = NULL, ...) {
if (!all(x$lag == as.integer(x$lag))) {
rlang::abort("step_lag() requires 'lag' argument to be integer-valued.")
}

step_lag_new(
terms = x$terms,
role = x$role,
Expand All @@ -107,29 +111,24 @@ prep.step_lag <- function(x, training, info = NULL, ...) {

#' @export
bake.step_lag <- function(object, new_data, ...) {
check_new_data(names(object$columns), object, new_data)

if (!all(object$lag == as.integer(object$lag))) {
rlang::abort("step_lag requires 'lag' argument to be integer valued.")
}
col_names <- names(object$columns)
check_new_data(col_names, object, new_data)

make_call <- function(col, lag_val) {
call2(
"lag",
x = sym(col),
n = lag_val,
default = object$default,
.ns = "dplyr"
for (col_name in col_names) {
new_values <- lapply(
object$lag,
function(x) dplyr::lag(new_data[[col_name]], x, default = object$default)
)
}

grid <- tidyr::expand_grid(col = object$columns, lag_val = object$lag)
calls <- purrr::map2(grid$col, grid$lag_val, make_call)
newname <- as.character(glue("{object$prefix}{grid$lag_val}_{grid$col}"))
calls <- check_name(calls, new_data, object, newname, TRUE)
new_names <- glue::glue("{object$prefix}{object$lag}_{col_name}")
names(new_values) <- new_names

new_values <- tibble::new_tibble(new_values)
new_values <- check_name(new_values, new_data, object, new_names)
new_data <- vctrs::vec_cbind(new_data, new_values)
}

new_data <- mutate(new_data, !!!calls)
new_data <- remove_original_cols(new_data, object, names(object$columns))
new_data <- remove_original_cols(new_data, object, col_names)
new_data
}

Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/lag.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
prepped_rec <- recipe(~., data = df) %>% step_lag(x, lag = 0.5) %>% prep(df)
Condition
Error in `step_lag()`:
Caused by error in `bake()`:
! step_lag requires 'lag' argument to be integer valued.
Caused by error in `prep()`:
! step_lag() requires 'lag' argument to be integer-valued.

# empty printing

Expand Down

0 comments on commit b7c6419

Please sign in to comment.