Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add helper for bridging causal fits #955

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.1.0.9000
Version: 1.1.0.9001
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ S3method(update,svm_rbf)
S3method(varying_args,model_spec)
S3method(varying_args,recipe)
S3method(varying_args,step)
S3method(weight_propensity,default)
S3method(weight_propensity,model_fit)
S3method(weight_propensity,model_spec)
export("%>%")
export(.censoring_weights_graf)
export(.check_glmnet_penalty_fit)
Expand Down Expand Up @@ -315,6 +318,7 @@ export(update_model_info_file)
export(update_spec)
export(varying)
export(varying_args)
export(weight_propensity)
export(xgb_predict)
export(xgb_train)
importFrom(dplyr,arrange)
Expand Down
90 changes: 90 additions & 0 deletions R/weight_propensity.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#' Helper for bridging two-stage causal fits
#'
#' @description
#' `weight_propensity()` is a helper function to more easily link the
#' propensity and outcome models in causal workflows. **The main documentation
#' for this function lives in the tune package at** `?tune::weight_propensity`.
#'
#' @param object The object containing the model fit(s) that will generate
#' predictions used to calculate propensity weights. Currently, either a
#' [parsnip model fit][parsnip::fit.model_spec()], fitted
#' [workflow][workflows::workflow()], or
#' tuning results (`?tune::fit_resamples`) object. If a tuning result, the
#' object must have been generated with the control argument
#' (`?tune::control_resamples`) `extract = identity`.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option would be to instead require save_pred = TRUE, but we couldn't make use of weight_propensity.workflow in that case. This approach is a bit more DRY.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like a little bit of a rough edge to me. I'm not sure we need to sand over it right now in terms of the interface but I would add more documentation, especially on the "main" doc page in tune, which currently only mentions this in an example. What about adding a sentence to the Details section, explaining why this needs to be set like this? I think that would help people remember.

#' @param wt_fn A function used to calculate the propensity weights. The first
#' argument gives the predicted probability of exposure, the true value for
#' which is provided in the second argument. See `?propensity::wt_ate()` for
#' an example.
Comment on lines +15 to +18
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't find that second sentence the easiest to read with the "the true value for which". Is the following correct?

Suggested change
#' @param wt_fn A function used to calculate the propensity weights. The first
#' argument gives the predicted probability of exposure, the true value for
#' which is provided in the second argument. See `?propensity::wt_ate()` for
#' an example.
#' @param wt_fn A function used to calculate the propensity weights. The first
#' argument gives the predicted probability of exposure, the second argument
#' gives the true value of exposure. See `?propensity::wt_ate()` for
#' an example.

#' @param .treated The level of the exposure corresponding to the treatment, as
#' a string. Additionally passed as `.treated` to `wt_fn`.
#' @param ... Additional arguments passed to `wt_fn`.
#' @param data The data supplied as the `data` argument to `fit()` the `object`.
#' This argument is only required for the `model_fit` and `workflow` methods---the
#' needed data for the `tune_results` method lives inside of `object`.
#'
#' @return
#' For `model_fit` and fitted `workflow` input, a modified version of the data
#' set supplied in `data` that contains a `.wts` column with class
#' `importance_weights`. For `tune_results` input, a modified version of the
#' resampling object underlying the tuning results containing a new `.wts` column
#' with propensity values corresponding to each element of the analysis set.
#'
#' @references Barrett M & D'Agostino McGowan L (forthcoming).
#' _Causal Inference in R_. \url{https://www.r-causal.org/}
#' @name weight_propensity.model_fit
NULL

#' @rdname weight_propensity.model_fit
#' @export
weight_propensity <- function(object, wt_fn, ...) {
UseMethod("weight_propensity")
}

#' @rdname weight_propensity.model_fit
#' @method weight_propensity default
#' @export
weight_propensity.default <- function(object, wt_fn, ...) {
abort("No known `weight_propensity()` method for this type of object.")
}

#' @noRd
#' @method weight_propensity model_spec
#' @export
weight_propensity.model_spec <- function(object, wt_fn, ...) {
abort(c(
"`weight_propensity()` is not well-defined for a model specification.",
"i" = "Supply `object` to `fit()` before generating propensity weights."
))
}

#' @rdname weight_propensity.model_fit
#' @method weight_propensity model_fit
#' @export
weight_propensity.model_fit <- function(object,
wt_fn,
.treated = object$lvl[2],
Copy link
Contributor Author

@simonpcouch simonpcouch Apr 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to pass .treated explicitly because the predicted probabilities depend on the treatment level and must match the argument supplied as .treated to propensity. This argument is roughly analogous to our event_level argument, where the event level is either "first" or "second" (rather than the actual level of the factor). Propensity not only parameterizes the argument differently, but the default is the second level, while ours is the first.

I've tried out a few different interfaces to this argument and don't feel strongly on how we can best handle this. We could alternatively add an event_level argument with the usual parameterization and then translate it to the right .treated level when interfacing with propensity. We could also ask that propensity changes the default/parameterization, though there is some information loss in the multi-level setting, and we need to translate "first"/"second" back to the level anyway at predict() (see L81-82).

Note that the current form of that argument is not checked / tested, pending a decision on how we want it to feel.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's worth discussing with Lucy and Malcolm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LucyMcGowan and @malcolmbarrett:

We're currently working on a set of PRs to better accommodate causal workflows in tidymodels via a helper, weight_propensity(), that bridges the propensity and outcome models during resampling. Some description of the bigger picture is here. The interface feels something like this:

  fit_resamples(
    propensity_workflow,
    resamples = bootstraps(data),
    control = control_resample(extract = identity)
  ) %>%
  weight_propensity(wt_ate, ...) %>%
  fit_resamples(
    outcome_workflow,
    resamples = .
  )

where the second argument to weight_propensity() is a propensity weighting function and further arguments are passed to that function—the helper handles the arguments .propensity and .exposure internally. The result of weight_propensity() in the above case is what the output of bootstraps(data %>% mutate(.wts = wt_ate(...))) would look like, where .propensity and .exposure are handled internally for the user.

There's surely lots to digest here, but do you have opinions on how we should open up the interface to the .treated argument? Feel free to give me a holler if you'd appreciate additional context. :)

Copy link
Contributor

@malcolmbarrett malcolmbarrett May 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is all awesome!

A few comments:

  1. we noticed that in propensity we also use the .treated terminology but now think that's a poor idea because not everything is a treatment. So, we're going to change that to .exposed, and I think it should probably be that here, too.
  2. One issue with that language and with that default value for what is currently .treated is that it only applies to binary and multiclass variables. For continuous variables, there won't be a default level. If this really becomes an issue to make it fit nicely, we're actually moving away from PS models for continuous exposures because of some mathematical issues with them.
  3. As for the default value, we do pick the second level because we assume 0 is unexposed and 1 is exposed, but that's mostly to do with what a common logistic model spec looks like. I'm going back to propensity soon and would be happy to work with you all to make this all consistent.

...,
data) {
if (rlang::is_missing(wt_fn) || !is.function(wt_fn)) {
abort("`wt_fn` must be a function.")
}

if (rlang::is_missing(data) || !is.data.frame(data)) {
abort("`data` must be the data supplied as the data argument to `fit()`.")
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only checks this PR makes on inputted data is that it's a data frame. This gives a window for folks to supply different data to fit() and weight_propensity(), probably resulting in uninformative errors at weight_propensity(). We could implement some functionality to "fingerprint" the training data, noting dims/column names or some other coarse set of identifying features, as a full hash would be too expensive to justify. Note that this is not an issue for the tune_results method (and its use of the workflow method), which accesses the training data via the underlying split and does not take a data argument.


# TODO: I'm not sure we have a way to identify `y` via a model
# spec fitted with `fit_xy()`---this will error in that case.
Comment on lines +77 to +78
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we started out with the "censored regression" mode, we required models to be fit via the formula interface, i.e., fit(), and had fit_xy() throw an error saying to use fit() for that mode.

In that spirit, you could add an error here to point people towards fit().

outcome_name <- object$preproc$y_var

preds <- predict(object, data, type = "prob")
preds <- preds[[paste0(".pred_", .treated)]]

data$.wts <-
hardhat::importance_weights(
wt_fn(preds, data[[outcome_name]], .treated = .treated, ...)
)

data
}
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ reference:
- translate
- starts_with("update")
- matches("_train")
- starts_with("weight_propensity")

- title: Developer tools
contents:
Expand Down
53 changes: 53 additions & 0 deletions man/weight_propensity.model_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 59 additions & 0 deletions tests/testthat/_snaps/weight_propensity.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# errors informatively with bad input

Code
weight_propensity(spec, silly_wt_fn, data = two_class_dat)
Condition
Error in `weight_propensity()`:
! `weight_propensity()` is not well-defined for a model specification.
i Supply `object` to `fit()` before generating propensity weights.

---

Code
weight_propensity("boop", silly_wt_fn, data = two_class_dat)
Condition
Error in `weight_propensity()`:
! No known `weight_propensity()` method for this type of object.

---

Code
weight_propensity(spec_fit, two_class_dat)
Condition
Error in `weight_propensity()`:
! `wt_fn` must be a function.

---

Code
weight_propensity(spec_fit, "boop", data = two_class_dat)
Condition
Error in `weight_propensity()`:
! `wt_fn` must be a function.

---

Code
weight_propensity(spec_fit, function(...) {
-1L
}, data = two_class_dat)
Condition
Error in `hardhat::importance_weights()`:
! `x` can't contain negative weights.

---

Code
weight_propensity(spec_fit, silly_wt_fn)
Condition
Error in `weight_propensity()`:
! `data` must be the data supplied as the data argument to `fit()`.

---

Code
weight_propensity(spec_fit, silly_wt_fn, data = "boop")
Condition
Error in `weight_propensity()`:
! `data` must be the data supplied as the data argument to `fit()`.

68 changes: 68 additions & 0 deletions tests/testthat/test-weight_propensity.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
test_that("basic functionality", {
skip_if_not_installed("modeldata")
library(modeldata)
library(parsnip)

silly_wt_fn <- function(.propensity, .exposure = NULL, ...) {
seq(1, 2, length.out = length(.propensity))
}

lr_fit <- fit(logistic_reg(), Class ~ A + B, two_class_dat)

lr_res1 <- weight_propensity(lr_fit, silly_wt_fn, data = two_class_dat)
expect_s3_class(lr_res1, "tbl_df")
expect_true(all(names(lr_res1) %in% c(names(two_class_dat), ".wts")))
expect_equal(lr_res1$.wts, importance_weights(seq(1, 2, length.out = nrow(two_class_dat))))
})

test_that("errors informatively with bad input", {
skip_if_not_installed("modeldata")
library(modeldata)
library(parsnip)

silly_wt_fn <- function(.propensity, .exposure = NULL, ...) {
seq(1, 2, length.out = length(.propensity))
}

# bad `object`
spec <- logistic_reg()

expect_snapshot(
error = TRUE,
weight_propensity(spec, silly_wt_fn, data = two_class_dat)
)

expect_snapshot(
error = TRUE,
weight_propensity("boop", silly_wt_fn, data = two_class_dat)
)

# bad `wt_fn`
spec_fit <- fit(spec, Class ~ A + B, data = two_class_dat)

expect_snapshot(
error = TRUE,
weight_propensity(spec_fit, two_class_dat)
)

expect_snapshot(
error = TRUE,
weight_propensity(spec_fit, "boop", data = two_class_dat)
)

expect_snapshot(
error = TRUE,
weight_propensity(spec_fit, function(...) {-1L}, data = two_class_dat)
)

# bad `data`
expect_snapshot(
error = TRUE,
weight_propensity(spec_fit, silly_wt_fn)
)

expect_snapshot(
error = TRUE,
weight_propensity(spec_fit, silly_wt_fn, data = "boop")
)
})