From 0365498df5d8413cc5f4a24650f0ef5ef0e99492 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 28 Apr 2023 16:05:34 -0400 Subject: [PATCH 1/3] add helper for bridging causal fits --- DESCRIPTION | 2 +- NAMESPACE | 4 + R/weight_propensity.R | 90 ++++++++++++++++++++++ man/weight_propensity.model_fit.Rd | 53 +++++++++++++ tests/testthat/_snaps/weight_propensity.md | 59 ++++++++++++++ tests/testthat/test-weight_propensity.R | 68 ++++++++++++++++ 6 files changed, 275 insertions(+), 1 deletion(-) create mode 100644 R/weight_propensity.R create mode 100644 man/weight_propensity.model_fit.Rd create mode 100644 tests/testthat/_snaps/weight_propensity.md create mode 100644 tests/testthat/test-weight_propensity.R diff --git a/DESCRIPTION b/DESCRIPTION index bca85c1b9..32b84c61e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.1.0.9000 +Version: 1.1.0.9001 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), diff --git a/NAMESPACE b/NAMESPACE index 63b4305b1..9eebfe33b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -145,6 +145,9 @@ S3method(update,svm_rbf) S3method(varying_args,model_spec) S3method(varying_args,recipe) S3method(varying_args,step) +S3method(weight_propensity,default) +S3method(weight_propensity,model_fit) +S3method(weight_propensity,model_spec) export("%>%") export(.censoring_weights_graf) export(.check_glmnet_penalty_fit) @@ -315,6 +318,7 @@ export(update_model_info_file) export(update_spec) export(varying) export(varying_args) +export(weight_propensity) export(xgb_predict) export(xgb_train) importFrom(dplyr,arrange) diff --git a/R/weight_propensity.R b/R/weight_propensity.R new file mode 100644 index 000000000..65ce83cf6 --- /dev/null +++ b/R/weight_propensity.R @@ -0,0 +1,90 @@ +#' Helper for bridging two-stage causal fits +#' +#' @description +#' `weight_propensity()` is a helper function to more easily link the +#' propensity and outcome models in causal workflows. **The main documentation +#' for this function lives in the tune package at** `?tune::weight_propensity`. +#' +#' @param object The object containing the model fit(s) that will generate +#' predictions used to calculate propensity weights. Currently, either a +#' [parsnip model fit][parsnip::fit.model_spec()], fitted +#' [workflow][workflows::workflow()], or +#' tuning results (`?tune::fit_resamples`) object. If a tuning result, the +#' object must have been generated with the control argument +#' (`?tune::control_resamples`) `extract = identity`. +#' @param wt_fn A function used to calculate the propensity weights. The first +#' argument gives the predicted probability of exposure, the true value for +#' which is provided in the second argument. See `?propensity::wt_ate()` for +#' an example. +#' @param .treated The level of the exposure corresponding to the treatment, as +#' a string. Additionally passed as `.treated` to `wt_fn`. +#' @param ... Additional arguments passed to `wt_fn`. +#' @param data The data supplied as the `data` argument to `fit()` the `object`. +#' This argument is only required for the `model_fit` and `workflow` methods---the +#' needed data for the `tune_results` method lives inside of `object`. +#' +#' @return +#' For `model_fit` and fitted `workflow` input, a modified version of the data +#' set supplied in `data` that contains a `.wts` column with class +#' `importance_weights`. For `tune_results` input, a modified version of the +#' resampling object underlying the tuning results containing a new `.wts` column +#' with propensity values corresponding to each element of the analysis set. +#' +#' @references Barrett M & D'Agostino McGowan L (forthcoming). +#' _Causal Inference in R_. \url{https://www.r-causal.org/} +#' @name weight_propensity.model_fit +NULL + +#' @rdname weight_propensity.model_fit +#' @export +weight_propensity <- function(object, wt_fn, ...) { + UseMethod("weight_propensity") +} + +#' @rdname weight_propensity.model_fit +#' @method weight_propensity default +#' @export +weight_propensity.default <- function(object, wt_fn, ...) { + abort("No known `weight_propensity()` method for this type of object.") +} + +#' @noRd +#' @method weight_propensity model_spec +#' @export +weight_propensity.model_spec <- function(object, wt_fn, ...) { + abort(c( + "`weight_propensity()` is not well-defined for a model specification.", + "i" = "Supply `object` to `fit()` before generating propensity weights." + )) +} + +#' @rdname weight_propensity.model_fit +#' @method weight_propensity model_fit +#' @export +weight_propensity.model_fit <- function(object, + wt_fn, + .treated = object$lvl[2], + ..., + data) { + if (rlang::is_missing(wt_fn) || !is.function(wt_fn)) { + abort("`wt_fn` must be a function.") + } + + if (rlang::is_missing(data) || !is.data.frame(data)) { + abort("`data` must be the data supplied as the data argument to `fit()`.") + } + + # TODO: I'm not sure we have a way to identify `y` via a model + # spec fitted with `fit_xy()`---this will error in that case. + outcome_name <- object$preproc$y_var + + preds <- predict(object, data, type = "prob") + preds <- preds[[paste0(".pred_", .treated)]] + + data$.wts <- + hardhat::importance_weights( + wt_fn(preds, data[[outcome_name]], .treated, ...) + ) + + data +} diff --git a/man/weight_propensity.model_fit.Rd b/man/weight_propensity.model_fit.Rd new file mode 100644 index 000000000..4154bb217 --- /dev/null +++ b/man/weight_propensity.model_fit.Rd @@ -0,0 +1,53 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/weight_propensity.R +\name{weight_propensity.model_fit} +\alias{weight_propensity.model_fit} +\alias{weight_propensity} +\alias{weight_propensity.default} +\title{Helper for bridging two-stage causal fits} +\usage{ +weight_propensity(object, wt_fn, ...) + +\method{weight_propensity}{default}(object, wt_fn, ...) + +\method{weight_propensity}{model_fit}(object, wt_fn, .treated = object$lvl[2], ..., data) +} +\arguments{ +\item{object}{The object containing the model fit(s) that will generate +predictions used to calculate propensity weights. Currently, either a +\link[=fit.model_spec]{parsnip model fit}, fitted +\link[workflows:workflow]{workflow}, or +tuning results (\code{?tune::fit_resamples}) object. If a tuning result, the +object must have been generated with the control argument +(\code{?tune::control_resamples}) \code{extract = identity}.} + +\item{wt_fn}{A function used to calculate the propensity weights. The first +argument gives the predicted probability of exposure, the true value for +which is provided in the second argument. See \code{?propensity::wt_ate()} for +an example.} + +\item{...}{Additional arguments passed to \code{wt_fn}.} + +\item{.treated}{The level of the exposure corresponding to the treatment, as +a string. Additionally passed as \code{.treated} to \code{wt_fn}.} + +\item{data}{The data supplied as the \code{data} argument to \code{fit()} the \code{object}. +This argument is only required for the \code{model_fit} and \code{workflow} methods---the +needed data for the \code{tune_results} method lives inside of \code{object}.} +} +\value{ +For \code{model_fit} and fitted \code{workflow} input, a modified version of the data +set supplied in \code{data} that contains a \code{.wts} column with class +\code{importance_weights}. For \code{tune_results} input, a modified version of the +resampling object underlying the tuning results containing a new \code{.wts} column +with propensity values corresponding to each element of the analysis set. +} +\description{ +\code{weight_propensity()} is a helper function to more easily link the +propensity and outcome models in causal workflows. \strong{The main documentation +for this function lives in the tune package at} \code{?tune::weight_propensity}. +} +\references{ +Barrett M & D'Agostino McGowan L (forthcoming). +\emph{Causal Inference in R}. \url{https://www.r-causal.org/} +} diff --git a/tests/testthat/_snaps/weight_propensity.md b/tests/testthat/_snaps/weight_propensity.md new file mode 100644 index 000000000..d6631bd11 --- /dev/null +++ b/tests/testthat/_snaps/weight_propensity.md @@ -0,0 +1,59 @@ +# errors informatively with bad input + + Code + weight_propensity(spec, silly_wt_fn, data = two_class_dat) + Condition + Error in `weight_propensity()`: + ! `weight_propensity()` is not well-defined for a model specification. + i Supply `object` to `fit()` before generating propensity weights. + +--- + + Code + weight_propensity("boop", silly_wt_fn, data = two_class_dat) + Condition + Error in `weight_propensity()`: + ! No known `weight_propensity()` method for this type of object. + +--- + + Code + weight_propensity(spec_fit, two_class_dat) + Condition + Error in `weight_propensity()`: + ! `wt_fn` must be a function. + +--- + + Code + weight_propensity(spec_fit, "boop", data = two_class_dat) + Condition + Error in `weight_propensity()`: + ! `wt_fn` must be a function. + +--- + + Code + weight_propensity(spec_fit, function(...) { + -1L + }, data = two_class_dat) + Condition + Error in `hardhat::importance_weights()`: + ! `x` can't contain negative weights. + +--- + + Code + weight_propensity(spec_fit, silly_wt_fn) + Condition + Error in `weight_propensity()`: + ! `data` must be the data supplied as the data argument to `fit()`. + +--- + + Code + weight_propensity(spec_fit, silly_wt_fn, data = "boop") + Condition + Error in `weight_propensity()`: + ! `data` must be the data supplied as the data argument to `fit()`. + diff --git a/tests/testthat/test-weight_propensity.R b/tests/testthat/test-weight_propensity.R new file mode 100644 index 000000000..9df49527c --- /dev/null +++ b/tests/testthat/test-weight_propensity.R @@ -0,0 +1,68 @@ +test_that("basic functionality", { + skip_if_not_installed("modeldata") + library(modeldata) + library(parsnip) + + silly_wt_fn <- function(.propensity, .exposure = NULL, ...) { + seq(1, 2, length.out = length(.propensity)) + } + + lr_fit <- fit(logistic_reg(), Class ~ A + B, two_class_dat) + + lr_res1 <- weight_propensity(lr_fit, silly_wt_fn, data = two_class_dat) + expect_s3_class(lr_res1, "tbl_df") + expect_true(all(names(lr_res1) %in% c(names(two_class_dat), ".wts"))) + expect_equal(lr_res1$.wts, importance_weights(seq(1, 2, length.out = nrow(two_class_dat)))) +}) + +test_that("errors informatively with bad input", { + skip_if_not_installed("modeldata") + library(modeldata) + library(parsnip) + + silly_wt_fn <- function(.propensity, .exposure = NULL, ...) { + seq(1, 2, length.out = length(.propensity)) + } + + # bad `object` + spec <- logistic_reg() + + expect_snapshot( + error = TRUE, + weight_propensity(spec, silly_wt_fn, data = two_class_dat) + ) + + expect_snapshot( + error = TRUE, + weight_propensity("boop", silly_wt_fn, data = two_class_dat) + ) + + # bad `wt_fn` + spec_fit <- fit(spec, Class ~ A + B, data = two_class_dat) + + expect_snapshot( + error = TRUE, + weight_propensity(spec_fit, two_class_dat) + ) + + expect_snapshot( + error = TRUE, + weight_propensity(spec_fit, "boop", data = two_class_dat) + ) + + expect_snapshot( + error = TRUE, + weight_propensity(spec_fit, function(...) {-1L}, data = two_class_dat) + ) + + # bad `data` + expect_snapshot( + error = TRUE, + weight_propensity(spec_fit, silly_wt_fn) + ) + + expect_snapshot( + error = TRUE, + weight_propensity(spec_fit, silly_wt_fn, data = "boop") + ) +}) From 540d5e8fd03f044b3bd84d186fc927b77fc18c5a Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 28 Apr 2023 16:16:25 -0400 Subject: [PATCH 2/3] name `.treated` --- R/weight_propensity.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/weight_propensity.R b/R/weight_propensity.R index 65ce83cf6..616f5afee 100644 --- a/R/weight_propensity.R +++ b/R/weight_propensity.R @@ -83,7 +83,7 @@ weight_propensity.model_fit <- function(object, data$.wts <- hardhat::importance_weights( - wt_fn(preds, data[[outcome_name]], .treated, ...) + wt_fn(preds, data[[outcome_name]], .treated = .treated, ...) ) data From dc7898d7bcd171b3f46e152951cf2792214a2340 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Fri, 28 Apr 2023 16:20:29 -0400 Subject: [PATCH 3/3] add pkgdown entry --- _pkgdown.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/_pkgdown.yml b/_pkgdown.yml index 12fe60196..b2f0b4358 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -92,6 +92,7 @@ reference: - translate - starts_with("update") - matches("_train") + - starts_with("weight_propensity") - title: Developer tools contents: