From e6e01c6bc0c6728cde258b10ecfe9b0285fdbf93 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 31 Jul 2023 16:37:27 -0700 Subject: [PATCH 1/6] loopify step_lag() --- R/lag.R | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/R/lag.R b/R/lag.R index add18f776..9511c984e 100644 --- a/R/lag.R +++ b/R/lag.R @@ -107,29 +107,28 @@ prep.step_lag <- function(x, training, info = NULL, ...) { #' @export bake.step_lag <- function(object, new_data, ...) { - check_new_data(names(object$columns), object, new_data) - - if (!all(object$lag == as.integer(object$lag))) { + if (!all(x$lag == as.integer(x$lag))) { rlang::abort("step_lag requires 'lag' argument to be integer valued.") } - make_call <- function(col, lag_val) { - call2( - "lag", - x = sym(col), - n = lag_val, - default = object$default, - .ns = "dplyr" + col_names <- names(object$columns) + check_new_data(col_names, object, new_data) + + for (col_name in col_names) { + new_values <- map( + object$lag, + function(x) dplyr::lag(new_data[[col_name]], x, default = object$default) ) - } - grid <- tidyr::expand_grid(col = object$columns, lag_val = object$lag) - calls <- purrr::map2(grid$col, grid$lag_val, make_call) - newname <- as.character(glue("{object$prefix}{grid$lag_val}_{grid$col}")) - calls <- check_name(calls, new_data, object, newname, TRUE) + new_names <- glue("{object$prefix}{object$lag}_{col_name}") + names(new_values) <- new_names + + new_values <- tibble::new_tibble(new_values) + new_values <- check_name(new_values, new_data, object, new_names) + new_data <- vec_cbind(new_data, new_values) + } - new_data <- mutate(new_data, !!!calls) - new_data <- remove_original_cols(new_data, object, names(object$columns)) + new_data <- remove_original_cols(new_data, object, col_names) new_data } From 95a8d3ffb64a210b70c638336f8a9f819225fc65 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 31 Jul 2023 16:37:53 -0700 Subject: [PATCH 2/6] move lag abort() into prep.step_lag() --- R/lag.R | 8 ++++---- tests/testthat/_snaps/lag.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/lag.R b/R/lag.R index 9511c984e..b7e37f397 100644 --- a/R/lag.R +++ b/R/lag.R @@ -91,6 +91,10 @@ step_lag_new <- #' @export prep.step_lag <- function(x, training, info = NULL, ...) { + if (!all(x$lag == as.integer(x$lag))) { + rlang::abort("step_lag requires 'lag' argument to be integer valued.") + } + step_lag_new( terms = x$terms, role = x$role, @@ -107,10 +111,6 @@ prep.step_lag <- function(x, training, info = NULL, ...) { #' @export bake.step_lag <- function(object, new_data, ...) { - if (!all(x$lag == as.integer(x$lag))) { - rlang::abort("step_lag requires 'lag' argument to be integer valued.") - } - col_names <- names(object$columns) check_new_data(col_names, object, new_data) diff --git a/tests/testthat/_snaps/lag.md b/tests/testthat/_snaps/lag.md index c1caba732..6dbfe0edf 100644 --- a/tests/testthat/_snaps/lag.md +++ b/tests/testthat/_snaps/lag.md @@ -4,7 +4,7 @@ prepped_rec <- recipe(~., data = df) %>% step_lag(x, lag = 0.5) %>% prep(df) Condition Error in `step_lag()`: - Caused by error in `bake()`: + Caused by error in `prep()`: ! step_lag requires 'lag' argument to be integer valued. # empty printing From b5d4fd1bfaaaf5aec0d7d4c3a24358d547a5f8ce Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 1 Aug 2023 12:32:15 -0700 Subject: [PATCH 3/6] namespace vec_cbind() --- R/lag.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/lag.R b/R/lag.R index b7e37f397..1d5cc17db 100644 --- a/R/lag.R +++ b/R/lag.R @@ -125,7 +125,7 @@ bake.step_lag <- function(object, new_data, ...) { new_values <- tibble::new_tibble(new_values) new_values <- check_name(new_values, new_data, object, new_names) - new_data <- vec_cbind(new_data, new_values) + new_data <- vctrs::vec_cbind(new_data, new_values) } new_data <- remove_original_cols(new_data, object, col_names) From 835ce075591bd9a84226adfd036db6fb05c80d46 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 1 Aug 2023 12:36:36 -0700 Subject: [PATCH 4/6] use lapply instead of map --- R/lag.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/lag.R b/R/lag.R index 1d5cc17db..667f09c89 100644 --- a/R/lag.R +++ b/R/lag.R @@ -115,7 +115,7 @@ bake.step_lag <- function(object, new_data, ...) { check_new_data(col_names, object, new_data) for (col_name in col_names) { - new_values <- map( + new_values <- lapply( object$lag, function(x) dplyr::lag(new_data[[col_name]], x, default = object$default) ) From 95dd40a0d6a72cf6d9f519aec6adc3643e43c1bc Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 1 Aug 2023 16:10:36 -0700 Subject: [PATCH 5/6] Apply suggestions from code review Co-authored-by: Simon P. Couch --- R/lag.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/lag.R b/R/lag.R index 667f09c89..9ee8910c8 100644 --- a/R/lag.R +++ b/R/lag.R @@ -92,7 +92,7 @@ step_lag_new <- #' @export prep.step_lag <- function(x, training, info = NULL, ...) { if (!all(x$lag == as.integer(x$lag))) { - rlang::abort("step_lag requires 'lag' argument to be integer valued.") + rlang::abort("step_lag() requires 'lag' argument to be integer-valued.") } step_lag_new( @@ -120,7 +120,7 @@ bake.step_lag <- function(object, new_data, ...) { function(x) dplyr::lag(new_data[[col_name]], x, default = object$default) ) - new_names <- glue("{object$prefix}{object$lag}_{col_name}") + new_names <- glue::glue("{object$prefix}{object$lag}_{col_name}") names(new_values) <- new_names new_values <- tibble::new_tibble(new_values) From 393d606b2c4fadac0efd22845d76bca23bd99c94 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 1 Aug 2023 16:12:13 -0700 Subject: [PATCH 6/6] rerender snapshot --- tests/testthat/_snaps/lag.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/_snaps/lag.md b/tests/testthat/_snaps/lag.md index 6dbfe0edf..1215690b6 100644 --- a/tests/testthat/_snaps/lag.md +++ b/tests/testthat/_snaps/lag.md @@ -5,7 +5,7 @@ Condition Error in `step_lag()`: Caused by error in `prep()`: - ! step_lag requires 'lag' argument to be integer valued. + ! step_lag() requires 'lag' argument to be integer-valued. # empty printing